1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use fs::{FakeFs, Fs};
  11use futures::{
  12    StreamExt,
  13    channel::{
  14        mpsc::{self, UnboundedReceiver},
  15        oneshot,
  16    },
  17};
  18use gpui::{
  19    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  20};
  21use indoc::indoc;
  22use language_model::{
  23    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  24    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  25    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  26    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  27};
  28use pretty_assertions::assert_eq;
  29use project::{
  30    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  31};
  32use prompt_store::ProjectContext;
  33use reqwest_client::ReqwestClient;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use serde_json::json;
  37use settings::{Settings, SettingsStore};
  38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  39use util::path;
  40
  41mod test_tools;
  42use test_tools::*;
  43
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    // Basic round trip: send a user message, stream back a single text chunk,
    // and verify the thread records it as an assistant message and reports an
    // EndTurn stop reason.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Simulate the model's streamed reply and end the turn.
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
  73
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that a streamed Thinking event is rendered in the assistant
    // message's markdown as a `<think>` block, preceding the regular text.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking step followed by the final answer.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
 117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Checks that the first message of a completion request is a system prompt
    // that reflects the project context (shell) and, because a tool is
    // registered, includes the diagnostics-fixing section.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // The shell configured above must be interpolated into the prompt.
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
 162
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    // Counterpart to test_system_prompt: when no tools are registered, the
    // tool-use and diagnostics sections must be absent from the system prompt.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
 198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: only the most recent message in a
    // completion request carries `cache: true`; all earlier history is sent
    // with `cache: false`. Checked across plain messages and tool results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    // Only the trailing tool-result message should be marked cacheable.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 332
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (Sonnet 4); only runs with the
    // "e2e" feature enabled.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The model should echo the delay tool's output ("Ding") in its
        // final message.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 392
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (Sonnet 4, "e2e" feature only): verifies that tool-use
    // input streams incrementally, i.e. the thread history exposes a tool use
    // whose input is still incomplete while the call is Pending.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may be mid-stream: "g" is
                        // a late field, so its absence indicates a partial
                        // tool use was observed.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 444
 445#[gpui::test]
 446async fn test_tool_authorization(cx: &mut TestAppContext) {
 447    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 448    let fake_model = model.as_fake();
 449
 450    let mut events = thread
 451        .update(cx, |thread, cx| {
 452            thread.add_tool(ToolRequiringPermission);
 453            thread.send(UserMessageId::new(), ["abc"], cx)
 454        })
 455        .unwrap();
 456    cx.run_until_parked();
 457    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 458        LanguageModelToolUse {
 459            id: "tool_id_1".into(),
 460            name: ToolRequiringPermission::name().into(),
 461            raw_input: "{}".into(),
 462            input: json!({}),
 463            is_input_complete: true,
 464        },
 465    ));
 466    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 467        LanguageModelToolUse {
 468            id: "tool_id_2".into(),
 469            name: ToolRequiringPermission::name().into(),
 470            raw_input: "{}".into(),
 471            input: json!({}),
 472            is_input_complete: true,
 473        },
 474    ));
 475    fake_model.end_last_completion_stream();
 476    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 477    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 478
 479    // Approve the first
 480    tool_call_auth_1
 481        .response
 482        .send(tool_call_auth_1.options[1].id.clone())
 483        .unwrap();
 484    cx.run_until_parked();
 485
 486    // Reject the second
 487    tool_call_auth_2
 488        .response
 489        .send(tool_call_auth_1.options[2].id.clone())
 490        .unwrap();
 491    cx.run_until_parked();
 492
 493    let completion = fake_model.pending_completions().pop().unwrap();
 494    let message = completion.messages.last().unwrap();
 495    assert_eq!(
 496        message.content,
 497        vec![
 498            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 499                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 500                tool_name: ToolRequiringPermission::name().into(),
 501                is_error: false,
 502                content: "Allowed".into(),
 503                output: Some("Allowed".into())
 504            }),
 505            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 506                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 507                tool_name: ToolRequiringPermission::name().into(),
 508                is_error: true,
 509                content: "Permission to run tool denied by user".into(),
 510                output: Some("Permission to run tool denied by user".into())
 511            })
 512        ]
 513    );
 514
 515    // Simulate yet another tool call.
 516    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 517        LanguageModelToolUse {
 518            id: "tool_id_3".into(),
 519            name: ToolRequiringPermission::name().into(),
 520            raw_input: "{}".into(),
 521            input: json!({}),
 522            is_input_complete: true,
 523        },
 524    ));
 525    fake_model.end_last_completion_stream();
 526
 527    // Respond by always allowing tools.
 528    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 529    tool_call_auth_3
 530        .response
 531        .send(tool_call_auth_3.options[0].id.clone())
 532        .unwrap();
 533    cx.run_until_parked();
 534    let completion = fake_model.pending_completions().pop().unwrap();
 535    let message = completion.messages.last().unwrap();
 536    assert_eq!(
 537        message.content,
 538        vec![language_model::MessageContent::ToolResult(
 539            LanguageModelToolResult {
 540                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 541                tool_name: ToolRequiringPermission::name().into(),
 542                is_error: false,
 543                content: "Allowed".into(),
 544                output: Some("Allowed".into())
 545            }
 546        )]
 547    );
 548
 549    // Simulate a final tool call, ensuring we don't trigger authorization.
 550    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 551        LanguageModelToolUse {
 552            id: "tool_id_4".into(),
 553            name: ToolRequiringPermission::name().into(),
 554            raw_input: "{}".into(),
 555            input: json!({}),
 556            is_input_complete: true,
 557        },
 558    ));
 559    fake_model.end_last_completion_stream();
 560    cx.run_until_parked();
 561    let completion = fake_model.pending_completions().pop().unwrap();
 562    let message = completion.messages.last().unwrap();
 563    assert_eq!(
 564        message.content,
 565        vec![language_model::MessageContent::ToolResult(
 566            LanguageModelToolResult {
 567                tool_use_id: "tool_id_4".into(),
 568                tool_name: ToolRequiringPermission::name().into(),
 569                is_error: false,
 570                content: "Allowed".into(),
 571                output: Some("Allowed".into())
 572            }
 573        )]
 574    );
 575}
 576
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    // When the model requests a tool that was never registered, the thread
    // should surface a Pending tool call that is then updated to Failed.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
 605
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, the send stream must end
    // with a ToolUseLimitReachedError, and `resume` should issue a new request
    // that appends a "Continue where you left off" user message (the only
    // cached one) to the existing history.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // messages[0] is the system prompt; comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
 714
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (rather than resuming) should still work: the new message is appended
    // after the preserved tool use/result history and is the only cached one.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Deliver a tool use and the limit notification in the same turn.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 788
 789async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 790    let event = events
 791        .next()
 792        .await
 793        .expect("no tool call authorization event received")
 794        .unwrap();
 795    match event {
 796        ThreadEvent::ToolCall(tool_call) => tool_call,
 797        event => {
 798            panic!("Unexpected event {event:?}");
 799        }
 800    }
 801}
 802
 803async fn expect_tool_call_update_fields(
 804    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 805) -> acp::ToolCallUpdate {
 806    let event = events
 807        .next()
 808        .await
 809        .expect("no tool call authorization event received")
 810        .unwrap();
 811    match event {
 812        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 813        event => {
 814            panic!("Unexpected event {event:?}");
 815        }
 816    }
 817}
 818
 819async fn next_tool_call_authorization(
 820    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 821) -> ToolCallAuthorization {
 822    loop {
 823        let event = events
 824            .next()
 825            .await
 826            .expect("no tool call authorization event received")
 827            .unwrap();
 828        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 829            let permission_kinds = tool_call_authorization
 830                .options
 831                .iter()
 832                .map(|o| o.kind)
 833                .collect::<Vec<_>>();
 834            assert_eq!(
 835                permission_kinds,
 836                vec![
 837                    acp::PermissionOptionKind::AllowAlways,
 838                    acp::PermissionOptionKind::AllowOnce,
 839                    acp::PermissionOptionKind::RejectOnce,
 840                ]
 841            );
 842            return tool_call_authorization;
 843        }
 844    }
 845}
 846
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (Sonnet 4, "e2e" feature only): two delay-tool calls
    // issued in one message should both complete, and the final message
    // should describe their output.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the last agent message and check
        // that the delay tool's output ("Ding") was echoed.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
 891
 892#[gpui::test]
 893async fn test_profiles(cx: &mut TestAppContext) {
 894    let ThreadTest {
 895        model, thread, fs, ..
 896    } = setup(cx, TestModel::Fake).await;
 897    let fake_model = model.as_fake();
 898
 899    thread.update(cx, |thread, _cx| {
 900        thread.add_tool(DelayTool);
 901        thread.add_tool(EchoTool);
 902        thread.add_tool(InfiniteTool);
 903    });
 904
 905    // Override profiles and wait for settings to be loaded.
 906    fs.insert_file(
 907        paths::settings_file(),
 908        json!({
 909            "agent": {
 910                "profiles": {
 911                    "test-1": {
 912                        "name": "Test Profile 1",
 913                        "tools": {
 914                            EchoTool::name(): true,
 915                            DelayTool::name(): true,
 916                        }
 917                    },
 918                    "test-2": {
 919                        "name": "Test Profile 2",
 920                        "tools": {
 921                            InfiniteTool::name(): true,
 922                        }
 923                    }
 924                }
 925            }
 926        })
 927        .to_string()
 928        .into_bytes(),
 929    )
 930    .await;
 931    cx.run_until_parked();
 932
 933    // Test that test-1 profile (default) has echo and delay tools
 934    thread
 935        .update(cx, |thread, cx| {
 936            thread.set_profile(AgentProfileId("test-1".into()));
 937            thread.send(UserMessageId::new(), ["test"], cx)
 938        })
 939        .unwrap();
 940    cx.run_until_parked();
 941
 942    let mut pending_completions = fake_model.pending_completions();
 943    assert_eq!(pending_completions.len(), 1);
 944    let completion = pending_completions.pop().unwrap();
 945    let tool_names: Vec<String> = completion
 946        .tools
 947        .iter()
 948        .map(|tool| tool.name.clone())
 949        .collect();
 950    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
 951    fake_model.end_last_completion_stream();
 952
 953    // Switch to test-2 profile, and verify that it has only the infinite tool.
 954    thread
 955        .update(cx, |thread, cx| {
 956            thread.set_profile(AgentProfileId("test-2".into()));
 957            thread.send(UserMessageId::new(), ["test2"], cx)
 958        })
 959        .unwrap();
 960    cx.run_until_parked();
 961    let mut pending_completions = fake_model.pending_completions();
 962    assert_eq!(pending_completions.len(), 1);
 963    let completion = pending_completions.pop().unwrap();
 964    let tool_names: Vec<String> = completion
 965        .tools
 966        .iter()
 967        .map(|tool| tool.name.clone())
 968        .collect();
 969    assert_eq!(tool_names, vec![InfiniteTool::name()]);
 970}
 971
/// Exercises MCP (context server) tools end to end: a server-provided tool is
/// offered to the model, invoked, and its result threaded back; when a native
/// tool with the same name is later added, the MCP tool is renamed with a
/// server prefix ("test_server_echo") to resolve the collision.
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register a fake context server exposing a single "echo" tool; the
    // returned stream yields (params, response-sender) pairs per invocation.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should have received the call; reply with its result.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): `completion` here is still the request popped above, so
    // this re-asserts data already checked; presumably this was meant to
    // inspect the follow-up completion request — verify.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // The model invokes both the renamed MCP tool and the native tool.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server still sees the original (unprefixed) tool name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1134
/// Verifies how tool names are disambiguated and length-limited when native
/// tools and multiple context servers expose colliding and over-long names:
/// the final expected list shows colliding names getting server prefixes
/// ("xxx_echo", "yyy_echo") and names at/over `MAX_TOOL_NAME_LENGTH` being
/// shortened (see the literal expectations at the end of the test).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit name, also offered by server "zzz" below, forcing
            // disambiguation of two long colliding names.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One past the limit, with no collision.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The request sent to the model should contain the fully resolved names:
    // colliding names prefixed with their server id, long colliding names
    // disambiguated with a shortened prefix ("y_…"/"z_…"), over-long unique
    // names truncated to fit.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1300
/// Verifies that cancelling a turn while a tool (the infinite tool) is still
/// running closes the event stream with `StopReason::Cancelled`, and that the
/// thread remains usable for a subsequent send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the expected order; remember the echo
            // call's id so we can detect its completion below.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo call transitioning to Completed.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have been called and echo has finished; the
        // infinite tool is, by design, still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1386
1387#[gpui::test]
1388async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1389    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1390    let fake_model = model.as_fake();
1391
1392    let events_1 = thread
1393        .update(cx, |thread, cx| {
1394            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1395        })
1396        .unwrap();
1397    cx.run_until_parked();
1398    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1399    cx.run_until_parked();
1400
1401    let events_2 = thread
1402        .update(cx, |thread, cx| {
1403            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1404        })
1405        .unwrap();
1406    cx.run_until_parked();
1407    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1408    fake_model
1409        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1410    fake_model.end_last_completion_stream();
1411
1412    let events_1 = events_1.collect::<Vec<_>>().await;
1413    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1414    let events_2 = events_2.collect::<Vec<_>>().await;
1415    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1416}
1417
1418#[gpui::test]
1419async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1420    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1421    let fake_model = model.as_fake();
1422
1423    let events_1 = thread
1424        .update(cx, |thread, cx| {
1425            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1426        })
1427        .unwrap();
1428    cx.run_until_parked();
1429    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1430    fake_model
1431        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1432    fake_model.end_last_completion_stream();
1433    let events_1 = events_1.collect::<Vec<_>>().await;
1434
1435    let events_2 = thread
1436        .update(cx, |thread, cx| {
1437            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1438        })
1439        .unwrap();
1440    cx.run_until_parked();
1441    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1442    fake_model
1443        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1444    fake_model.end_last_completion_stream();
1445    let events_2 = events_2.collect::<Vec<_>>().await;
1446
1447    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1448    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1449}
1450
1451#[gpui::test]
1452async fn test_refusal(cx: &mut TestAppContext) {
1453    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1454    let fake_model = model.as_fake();
1455
1456    let events = thread
1457        .update(cx, |thread, cx| {
1458            thread.send(UserMessageId::new(), ["Hello"], cx)
1459        })
1460        .unwrap();
1461    cx.run_until_parked();
1462    thread.read_with(cx, |thread, _| {
1463        assert_eq!(
1464            thread.to_markdown(),
1465            indoc! {"
1466                ## User
1467
1468                Hello
1469            "}
1470        );
1471    });
1472
1473    fake_model.send_last_completion_stream_text_chunk("Hey!");
1474    cx.run_until_parked();
1475    thread.read_with(cx, |thread, _| {
1476        assert_eq!(
1477            thread.to_markdown(),
1478            indoc! {"
1479                ## User
1480
1481                Hello
1482
1483                ## Assistant
1484
1485                Hey!
1486            "}
1487        );
1488    });
1489
1490    // If the model refuses to continue, the thread should remove all the messages after the last user message.
1491    fake_model
1492        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1493    let events = events.collect::<Vec<_>>().await;
1494    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1495    thread.read_with(cx, |thread, _| {
1496        assert_eq!(thread.to_markdown(), "");
1497    });
1498}
1499
/// Truncating at the first (only) user message should empty the transcript
/// and reset token usage, while still allowing a fresh send afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage reported until the model streams a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the streamed UsageUpdate (input + output tokens).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from the new turn replaces the cleared state.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1615
/// Truncating at the second user message should restore the thread to its
/// state after the first exchange, including the first turn's token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message 1 + response + usage.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion for the post-first-exchange state; reused after
    // truncation to prove the thread rolled back exactly to this point.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange: keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and verify the rollback.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1724
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // Verifies that the summarization model generates a thread title after the
    // first exchange, and that no second title request is made for later turns.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Use a dedicated fake model for summarization so its completion requests
    // can be observed independently of the main conversation model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary stream produces text, the default title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The title is streamed in chunks; only the first line should be kept.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summarization model must not have been asked for a new completion,
    // and the original title must be preserved.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1770
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // Verifies that a tool use which is still pending (here: awaiting user
    // permission) is omitted when building the next completion request, while
    // completed tool uses and their results are included.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool will remain pending because its permission prompt is never
    // answered in this test.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool runs to completion and should appear in the built request.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system/context message; compare everything after it.
    // Note the permission tool use is absent from the assistant message.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1845
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of the ACP connection surface: thread creation,
    // model listing/selection, prompting, cancellation, and session cleanup
    // when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the connection and stream back a response chunk.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with a session-not-found error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1983
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the full ACP event sequence for a streamed tool call:
    // initial ToolCall (pending) -> input-complete update -> in-progress
    // update -> content update -> completed update with raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First event: the tool call is surfaced in Pending state with the
    // partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // Second event: completed input replaces the raw_input field.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Third event: the tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Fourth event: the tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Final event: completion with the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2090
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // Verifies that a completion which succeeds on the first attempt emits
    // no Retry events, even in Burn mode (where retries are enabled).
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream up to the Stop event, collecting any retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2134
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // Verifies that a retryable provider error (ServerOverloaded) triggers a
    // single retry after the advertised delay, and that the thread markdown
    // records the resumed completion as a separate assistant message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial response, then fail with a retryable error that asks
    // for a 3-second backoff.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion succeeds.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2199
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // Verifies that when a completion errors after streaming a tool use, the
    // tool call still runs to completion and its result is included in the
    // retried request.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    // Stream the tool use, then fail the completion with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    // The retried request must contain the tool use AND its completed result,
    // proving the tool finished despite the error.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2275
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // Verifies that after MAX_RETRY_ATTEMPTS consecutive retryable failures,
    // the thread gives up and surfaces the underlying error to the caller.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence MAX + 1 errors),
    // advancing the clock past each backoff so the next retry fires.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // Exactly one Retry event per attempt, numbered 1..=MAX.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported once, preserving the original error type.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2330
2331/// Filters out the stop events for asserting against in tests
2332fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2333    result_events
2334        .into_iter()
2335        .filter_map(|event| match event.unwrap() {
2336            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2337            _ => None,
2338        })
2339        .collect()
2340}
2341
/// Everything a test needs after `setup`: the model under test, the thread,
/// and handles to supporting state.
struct ThreadTest {
    /// The language model the thread completes against (fake or real).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread; mutable from tests.
    project_context: Entity<ProjectContext>,
    /// Store used to register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing the project and settings file.
    fs: Arc<FakeFs>,
}
2349
/// Which language model `setup` should wire into the thread.
enum TestModel {
    /// A real Anthropic model, requiring authentication (used for live tests).
    Sonnet4,
    /// An in-process `FakeLanguageModel` driven manually by the test.
    Fake,
}
2354
2355impl TestModel {
2356    fn id(&self) -> LanguageModelId {
2357        match self {
2358            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2359            TestModel::Fake => unreachable!(),
2360        }
2361    }
2362}
2363
/// Builds a test fixture: fake filesystem with an agent-settings file, a test
/// project, the requested language model, and a `Thread` wired to all of them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file that enables all test tools under a dedicated
    // profile, so tools added via `thread.add_tool` are permitted.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            TestModel::Fake => {}
            // Real-model tests need the full client/auth stack and a real
            // HTTP client.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are constructed directly; real models
    // are looked up in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2468
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only install the logger when the caller opted in via RUST_LOG, so test
    // output stays quiet by default.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2476
2477fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2478    let fs = fs.clone();
2479    cx.spawn({
2480        async move |cx| {
2481            let mut new_settings_content_rx = settings::watch_config_file(
2482                cx.background_executor(),
2483                fs,
2484                paths::settings_file().clone(),
2485            );
2486
2487            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2488                cx.update(|cx| {
2489                    SettingsStore::update_global(cx, |settings, cx| {
2490                        settings.set_user_settings(&new_settings_content, cx)
2491                    })
2492                })
2493                .ok();
2494            }
2495        }
2496    })
2497    .detach();
2498}
2499
2500fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2501    completion
2502        .tools
2503        .iter()
2504        .map(|tool| tool.name.clone())
2505        .collect()
2506}
2507
/// Registers and starts a fake MCP context server named `name` exposing
/// `tools`, and returns a channel of incoming tool calls.
///
/// Each received item pairs the call's parameters with a oneshot sender the
/// test must use to supply the tool's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    // The command is never executed: the fake transport below intercepts I/O.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport answering the MCP handshake and tool requests in-process.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block until it responds
                // through the paired oneshot sender.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake and tool listing complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}