1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
118#[gpui::test]
119async fn test_system_prompt(cx: &mut TestAppContext) {
120 let ThreadTest {
121 model,
122 thread,
123 project_context,
124 ..
125 } = setup(cx, TestModel::Fake).await;
126 let fake_model = model.as_fake();
127
128 project_context.update(cx, |project_context, _cx| {
129 project_context.shell = "test-shell".into()
130 });
131 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
132 thread
133 .update(cx, |thread, cx| {
134 thread.send(UserMessageId::new(), ["abc"], cx)
135 })
136 .unwrap();
137 cx.run_until_parked();
138 let mut pending_completions = fake_model.pending_completions();
139 assert_eq!(
140 pending_completions.len(),
141 1,
142 "unexpected pending completions: {:?}",
143 pending_completions
144 );
145
146 let pending_completion = pending_completions.pop().unwrap();
147 assert_eq!(pending_completion.messages[0].role, Role::System);
148
149 let system_message = &pending_completion.messages[0];
150 let system_prompt = system_message.content[0].to_str().unwrap();
151 assert!(
152 system_prompt.contains("test-shell"),
153 "unexpected system message: {:?}",
154 system_message
155 );
156 assert!(
157 system_prompt.contains("## Fixing Diagnostics"),
158 "unexpected system message: {:?}",
159 system_message
160 );
161}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: in every request sent to the model,
    // only the *latest* message carries `cache: true`, marking the end of the
    // cacheable prefix. This is checked across three turns: a plain message,
    // a follow-up message, and a tool-use turn whose tool result becomes the
    // final (cached) message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The echo tool just mirrors its input, so the expected result is "test".
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
344
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the `e2e` feature):
    // exercises tool calls that finish both before and after the model stops
    // streaming, and checks the turn completes normally in each case.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops
    // (DelayTool sleeps before returning its "Ding" output).
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should have echoed the delay tool's output in its reply.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
404
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4; only runs with the `e2e` feature):
    // verifies that tool-call input is streamed incrementally, i.e. the thread
    // history exposes a tool use whose input is still incomplete while the
    // call is pending, and complete input once the call leaves Pending.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a turn that should trigger the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // the last field ("g") hasn't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
456
457#[gpui::test]
458async fn test_tool_authorization(cx: &mut TestAppContext) {
459 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
460 let fake_model = model.as_fake();
461
462 let mut events = thread
463 .update(cx, |thread, cx| {
464 thread.add_tool(ToolRequiringPermission);
465 thread.send(UserMessageId::new(), ["abc"], cx)
466 })
467 .unwrap();
468 cx.run_until_parked();
469 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
470 LanguageModelToolUse {
471 id: "tool_id_1".into(),
472 name: ToolRequiringPermission::name().into(),
473 raw_input: "{}".into(),
474 input: json!({}),
475 is_input_complete: true,
476 thought_signature: None,
477 },
478 ));
479 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
480 LanguageModelToolUse {
481 id: "tool_id_2".into(),
482 name: ToolRequiringPermission::name().into(),
483 raw_input: "{}".into(),
484 input: json!({}),
485 is_input_complete: true,
486 thought_signature: None,
487 },
488 ));
489 fake_model.end_last_completion_stream();
490 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
491 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
492
493 // Approve the first
494 tool_call_auth_1
495 .response
496 .send(tool_call_auth_1.options[1].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499
500 // Reject the second
501 tool_call_auth_2
502 .response
503 .send(tool_call_auth_1.options[2].id.clone())
504 .unwrap();
505 cx.run_until_parked();
506
507 let completion = fake_model.pending_completions().pop().unwrap();
508 let message = completion.messages.last().unwrap();
509 assert_eq!(
510 message.content,
511 vec![
512 language_model::MessageContent::ToolResult(LanguageModelToolResult {
513 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
514 tool_name: ToolRequiringPermission::name().into(),
515 is_error: false,
516 content: "Allowed".into(),
517 output: Some("Allowed".into())
518 }),
519 language_model::MessageContent::ToolResult(LanguageModelToolResult {
520 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
521 tool_name: ToolRequiringPermission::name().into(),
522 is_error: true,
523 content: "Permission to run tool denied by user".into(),
524 output: Some("Permission to run tool denied by user".into())
525 })
526 ]
527 );
528
529 // Simulate yet another tool call.
530 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
531 LanguageModelToolUse {
532 id: "tool_id_3".into(),
533 name: ToolRequiringPermission::name().into(),
534 raw_input: "{}".into(),
535 input: json!({}),
536 is_input_complete: true,
537 thought_signature: None,
538 },
539 ));
540 fake_model.end_last_completion_stream();
541
542 // Respond by always allowing tools.
543 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
544 tool_call_auth_3
545 .response
546 .send(tool_call_auth_3.options[0].id.clone())
547 .unwrap();
548 cx.run_until_parked();
549 let completion = fake_model.pending_completions().pop().unwrap();
550 let message = completion.messages.last().unwrap();
551 assert_eq!(
552 message.content,
553 vec![language_model::MessageContent::ToolResult(
554 LanguageModelToolResult {
555 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
556 tool_name: ToolRequiringPermission::name().into(),
557 is_error: false,
558 content: "Allowed".into(),
559 output: Some("Allowed".into())
560 }
561 )]
562 );
563
564 // Simulate a final tool call, ensuring we don't trigger authorization.
565 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
566 LanguageModelToolUse {
567 id: "tool_id_4".into(),
568 name: ToolRequiringPermission::name().into(),
569 raw_input: "{}".into(),
570 input: json!({}),
571 is_input_complete: true,
572 thought_signature: None,
573 },
574 ));
575 fake_model.end_last_completion_stream();
576 cx.run_until_parked();
577 let completion = fake_model.pending_completions().pop().unwrap();
578 let message = completion.messages.last().unwrap();
579 assert_eq!(
580 message.content,
581 vec![language_model::MessageContent::ToolResult(
582 LanguageModelToolResult {
583 tool_use_id: "tool_id_4".into(),
584 tool_name: ToolRequiringPermission::name().into(),
585 is_error: false,
586 content: "Allowed".into(),
587 output: Some("Allowed".into())
588 }
589 )]
590 );
591}
592
593#[gpui::test]
594async fn test_tool_hallucination(cx: &mut TestAppContext) {
595 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
596 let fake_model = model.as_fake();
597
598 let mut events = thread
599 .update(cx, |thread, cx| {
600 thread.send(UserMessageId::new(), ["abc"], cx)
601 })
602 .unwrap();
603 cx.run_until_parked();
604 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
605 LanguageModelToolUse {
606 id: "tool_id_1".into(),
607 name: "nonexistent_tool".into(),
608 raw_input: "{}".into(),
609 input: json!({}),
610 is_input_complete: true,
611 thought_signature: None,
612 },
613 ));
614 fake_model.end_last_completion_stream();
615
616 let tool_call = expect_tool_call(&mut events).await;
617 assert_eq!(tool_call.title, "nonexistent_tool");
618 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
619 let update = expect_tool_call_update_fields(&mut events).await;
620 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
621}
622
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the model reports a tool-use limit, the turn should end with a
    // `ToolUseLimitReachedError`, and `resume` should start a new request that
    // replays the history plus a "Continue where you left off" user message
    // (which becomes the new cache point).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The echo tool ran; its result should be the final (cached) message in
    // the follow-up request. messages[0] is the system prompt.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming should re-send the history with an extra continuation message,
    // which now carries the cache marker.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    // The resumed turn finishes normally and its text lands in the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
737
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After a tool-use limit error, sending a brand-new user message (instead
    // of resuming) should work: the history, including the interrupted tool
    // call and its result, is replayed followed by the new message, which
    // becomes the cache point.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    // Expected result of the echo tool run below.
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit the tool use and immediately hit the limit in the same turn.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new message "ghi" is last and
    // carries the cache marker.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
814
815async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
816 let event = events
817 .next()
818 .await
819 .expect("no tool call authorization event received")
820 .unwrap();
821 match event {
822 ThreadEvent::ToolCall(tool_call) => tool_call,
823 event => {
824 panic!("Unexpected event {event:?}");
825 }
826 }
827}
828
829async fn expect_tool_call_update_fields(
830 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
831) -> acp::ToolCallUpdate {
832 let event = events
833 .next()
834 .await
835 .expect("no tool call authorization event received")
836 .unwrap();
837 match event {
838 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
839 event => {
840 panic!("Unexpected event {event:?}");
841 }
842 }
843}
844
845async fn next_tool_call_authorization(
846 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
847) -> ToolCallAuthorization {
848 loop {
849 let event = events
850 .next()
851 .await
852 .expect("no tool call authorization event received")
853 .unwrap();
854 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
855 let permission_kinds = tool_call_authorization
856 .options
857 .iter()
858 .map(|o| o.kind)
859 .collect::<Vec<_>>();
860 assert_eq!(
861 permission_kinds,
862 vec![
863 acp::PermissionOptionKind::AllowAlways,
864 acp::PermissionOptionKind::AllowOnce,
865 acp::PermissionOptionKind::RejectOnce,
866 ]
867 );
868 return tool_call_authorization;
869 }
870 }
871}
872
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4; only runs with the `e2e` feature):
    // the model is asked to issue two delay tool calls in one message; both
    // should run, the turn should end normally, and the final response should
    // mention the tools' "Ding" output.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check
        // that the delay tool's output made it into the model's description.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
917
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that agent profiles from the settings file control which of
    // the thread's registered tools are offered to the model, and that
    // switching profiles mid-thread takes effect on the next request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; profiles below will filter this set.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
997
998#[gpui::test]
999async fn test_mcp_tools(cx: &mut TestAppContext) {
1000 let ThreadTest {
1001 model,
1002 thread,
1003 context_server_store,
1004 fs,
1005 ..
1006 } = setup(cx, TestModel::Fake).await;
1007 let fake_model = model.as_fake();
1008
1009 // Override profiles and wait for settings to be loaded.
1010 fs.insert_file(
1011 paths::settings_file(),
1012 json!({
1013 "agent": {
1014 "always_allow_tool_actions": true,
1015 "profiles": {
1016 "test": {
1017 "name": "Test Profile",
1018 "enable_all_context_servers": true,
1019 "tools": {
1020 EchoTool::name(): true,
1021 }
1022 },
1023 }
1024 }
1025 })
1026 .to_string()
1027 .into_bytes(),
1028 )
1029 .await;
1030 cx.run_until_parked();
1031 thread.update(cx, |thread, cx| {
1032 thread.set_profile(AgentProfileId("test".into()), cx)
1033 });
1034
1035 let mut mcp_tool_calls = setup_context_server(
1036 "test_server",
1037 vec![context_server::types::Tool {
1038 name: "echo".into(),
1039 description: None,
1040 input_schema: serde_json::to_value(EchoTool::input_schema(
1041 LanguageModelToolSchemaFormat::JsonSchema,
1042 ))
1043 .unwrap(),
1044 output_schema: None,
1045 annotations: None,
1046 }],
1047 &context_server_store,
1048 cx,
1049 );
1050
1051 let events = thread.update(cx, |thread, cx| {
1052 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1053 });
1054 cx.run_until_parked();
1055
1056 // Simulate the model calling the MCP tool.
1057 let completion = fake_model.pending_completions().pop().unwrap();
1058 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1059 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1060 LanguageModelToolUse {
1061 id: "tool_1".into(),
1062 name: "echo".into(),
1063 raw_input: json!({"text": "test"}).to_string(),
1064 input: json!({"text": "test"}),
1065 is_input_complete: true,
1066 thought_signature: None,
1067 },
1068 ));
1069 fake_model.end_last_completion_stream();
1070 cx.run_until_parked();
1071
1072 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1073 assert_eq!(tool_call_params.name, "echo");
1074 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1075 tool_call_response
1076 .send(context_server::types::CallToolResponse {
1077 content: vec![context_server::types::ToolResponseContent::Text {
1078 text: "test".into(),
1079 }],
1080 is_error: None,
1081 meta: None,
1082 structured_content: None,
1083 })
1084 .unwrap();
1085 cx.run_until_parked();
1086
1087 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1088 fake_model.send_last_completion_stream_text_chunk("Done!");
1089 fake_model.end_last_completion_stream();
1090 events.collect::<Vec<_>>().await;
1091
1092 // Send again after adding the echo tool, ensuring the name collision is resolved.
1093 let events = thread.update(cx, |thread, cx| {
1094 thread.add_tool(EchoTool);
1095 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1096 });
1097 cx.run_until_parked();
1098 let completion = fake_model.pending_completions().pop().unwrap();
1099 assert_eq!(
1100 tool_names_for_completion(&completion),
1101 vec!["echo", "test_server_echo"]
1102 );
1103 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1104 LanguageModelToolUse {
1105 id: "tool_2".into(),
1106 name: "test_server_echo".into(),
1107 raw_input: json!({"text": "mcp"}).to_string(),
1108 input: json!({"text": "mcp"}),
1109 is_input_complete: true,
1110 thought_signature: None,
1111 },
1112 ));
1113 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1114 LanguageModelToolUse {
1115 id: "tool_3".into(),
1116 name: "echo".into(),
1117 raw_input: json!({"text": "native"}).to_string(),
1118 input: json!({"text": "native"}),
1119 is_input_complete: true,
1120 thought_signature: None,
1121 },
1122 ));
1123 fake_model.end_last_completion_stream();
1124 cx.run_until_parked();
1125
1126 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1127 assert_eq!(tool_call_params.name, "echo");
1128 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1129 tool_call_response
1130 .send(context_server::types::CallToolResponse {
1131 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1132 is_error: None,
1133 meta: None,
1134 structured_content: None,
1135 })
1136 .unwrap();
1137 cx.run_until_parked();
1138
1139 // Ensure the tool results were inserted with the correct names.
1140 let completion = fake_model.pending_completions().pop().unwrap();
1141 assert_eq!(
1142 completion.messages.last().unwrap().content,
1143 vec![
1144 MessageContent::ToolResult(LanguageModelToolResult {
1145 tool_use_id: "tool_3".into(),
1146 tool_name: "echo".into(),
1147 is_error: false,
1148 content: "native".into(),
1149 output: Some("native".into()),
1150 },),
1151 MessageContent::ToolResult(LanguageModelToolResult {
1152 tool_use_id: "tool_2".into(),
1153 tool_name: "test_server_echo".into(),
1154 is_error: false,
1155 content: "mcp".into(),
1156 output: Some("mcp".into()),
1157 },),
1158 ]
1159 );
1160 fake_model.end_last_completion_stream();
1161 events.collect::<Vec<_>>().await;
1162}
1163
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how MCP (context server) tool names are disambiguated and
    // truncated when multiple servers expose conflicting or over-long names.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names, also duplicated by server "zzz" below.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Exceeds the limit outright; must be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected resolution (as asserted below): names conflicting with a native
    // tool get a server-id prefix ("xxx_echo", "yyy_echo"); names that would
    // exceed MAX_TOOL_NAME_LENGTH after prefixing appear to be shortened with a
    // single-letter server prefix ("y_aaa…", "z_aaa…") — exact scheme lives in
    // the production code, not visible here.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1329
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)] // runs against a real model; only in e2e CI
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end: cancelling a turn mid-tool-call must close the event stream
    // and leave the thread usable for a subsequent send.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools must be reported in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Only the echo tool can finish; the infinite tool never completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1415
1416#[gpui::test]
1417async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1418 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1419 let fake_model = model.as_fake();
1420
1421 let events_1 = thread
1422 .update(cx, |thread, cx| {
1423 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1424 })
1425 .unwrap();
1426 cx.run_until_parked();
1427 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1428 cx.run_until_parked();
1429
1430 let events_2 = thread
1431 .update(cx, |thread, cx| {
1432 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1433 })
1434 .unwrap();
1435 cx.run_until_parked();
1436 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1437 fake_model
1438 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1439 fake_model.end_last_completion_stream();
1440
1441 let events_1 = events_1.collect::<Vec<_>>().await;
1442 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1443 let events_2 = events_2.collect::<Vec<_>>().await;
1444 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1445}
1446
1447#[gpui::test]
1448async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1449 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1450 let fake_model = model.as_fake();
1451
1452 let events_1 = thread
1453 .update(cx, |thread, cx| {
1454 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1455 })
1456 .unwrap();
1457 cx.run_until_parked();
1458 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1459 fake_model
1460 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1461 fake_model.end_last_completion_stream();
1462 let events_1 = events_1.collect::<Vec<_>>().await;
1463
1464 let events_2 = thread
1465 .update(cx, |thread, cx| {
1466 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1467 })
1468 .unwrap();
1469 cx.run_until_parked();
1470 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1471 fake_model
1472 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1473 fake_model.end_last_completion_stream();
1474 let events_2 = events_2.collect::<Vec<_>>().await;
1475
1476 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1477 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1478}
1479
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // A model refusal must roll the thread back to before the refused turn's
    // user message (here, back to an empty thread).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // Streamed assistant text is visible in the transcript immediately.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1528
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) message must empty the thread, reset
    // token usage, and still allow sending afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported by the model yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the UsageUpdate above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: everything is removed, usage resets.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message is added synchronously, before the model replies.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage now reflects the post-truncation turn only.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1644
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second message must restore the thread (and its token
    // usage) to exactly the state after the first exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused both before and after truncation: the thread
    // should look identical in both cases.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose id we keep so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the second message drops it and its response, restoring
    // the first-message state (including the earlier token usage).
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1753
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A dedicated summarization model generates the thread title after the
    // first exchange; later exchanges must not regenerate it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for title/summary generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks deliberately split words and include a newline; only the
    // first line ("Hello world") becomes the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1799
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When building a completion request, tool calls that are still awaiting
    // resolution (e.g. user permission) must be omitted, while completed tool
    // calls are included with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending because permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool call runs to completion and produces a result.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped: it is the system/context message added by setup.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // Only the echo tool use appears; the permission tool is pending.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
1879
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Integration test of the NativeAgent / AgentConnection surface: thread
    // creation, model listing/selection, prompting, cancellation, and session
    // teardown when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a response from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2014
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the exact sequence of ACP events emitted for a tool call:
    // initial ToolCall (Pending), input update, InProgress, content, and
    // final Completed with the tool's raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. Initial tool call: Pending status, partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. Status transition to InProgress once the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. Streaming content produced by the tool.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. Final transition to Completed with the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2123
2124#[gpui::test]
2125async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2126 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2127 let fake_model = model.as_fake();
2128
2129 let mut events = thread
2130 .update(cx, |thread, cx| {
2131 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2132 thread.send(UserMessageId::new(), ["Hello!"], cx)
2133 })
2134 .unwrap();
2135 cx.run_until_parked();
2136
2137 fake_model.send_last_completion_stream_text_chunk("Hey!");
2138 fake_model.end_last_completion_stream();
2139
2140 let mut retry_events = Vec::new();
2141 while let Some(Ok(event)) = events.next().await {
2142 match event {
2143 ThreadEvent::Retry(retry_status) => {
2144 retry_events.push(retry_status);
2145 }
2146 ThreadEvent::Stop(..) => break,
2147 _ => {}
2148 }
2149 }
2150
2151 assert_eq!(retry_events.len(), 0);
2152 thread.read_with(cx, |thread, _cx| {
2153 assert_eq!(
2154 thread.to_markdown(),
2155 indoc! {"
2156 ## User
2157
2158 Hello!
2159
2160 ## Assistant
2161
2162 Hey!
2163 "}
2164 )
2165 });
2166}
2167
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A ServerOverloaded error mid-stream should trigger a single retry after
    // the provider-supplied delay, resuming the turn in a new assistant message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial reply, then fail with an overloaded error that asks
    // for a 3-second backoff.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the partial reply and marks the resumption point.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2232
2233#[gpui::test]
2234async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
2235 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2236 let fake_model = model.as_fake();
2237
2238 let events = thread
2239 .update(cx, |thread, cx| {
2240 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2241 thread.add_tool(EchoTool);
2242 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
2243 })
2244 .unwrap();
2245 cx.run_until_parked();
2246
2247 let tool_use_1 = LanguageModelToolUse {
2248 id: "tool_1".into(),
2249 name: EchoTool::name().into(),
2250 raw_input: json!({"text": "test"}).to_string(),
2251 input: json!({"text": "test"}),
2252 is_input_complete: true,
2253 thought_signature: None,
2254 };
2255 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2256 tool_use_1.clone(),
2257 ));
2258 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2259 provider: LanguageModelProviderName::new("Anthropic"),
2260 retry_after: Some(Duration::from_secs(3)),
2261 });
2262 fake_model.end_last_completion_stream();
2263
2264 cx.executor().advance_clock(Duration::from_secs(3));
2265 let completion = fake_model.pending_completions().pop().unwrap();
2266 assert_eq!(
2267 completion.messages[1..],
2268 vec![
2269 LanguageModelRequestMessage {
2270 role: Role::User,
2271 content: vec!["Call the echo tool!".into()],
2272 cache: false,
2273 reasoning_details: None,
2274 },
2275 LanguageModelRequestMessage {
2276 role: Role::Assistant,
2277 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
2278 cache: false,
2279 reasoning_details: None,
2280 },
2281 LanguageModelRequestMessage {
2282 role: Role::User,
2283 content: vec![language_model::MessageContent::ToolResult(
2284 LanguageModelToolResult {
2285 tool_use_id: tool_use_1.id.clone(),
2286 tool_name: tool_use_1.name.clone(),
2287 is_error: false,
2288 content: "test".into(),
2289 output: Some("test".into())
2290 }
2291 )],
2292 cache: true,
2293 reasoning_details: None,
2294 },
2295 ]
2296 );
2297
2298 fake_model.send_last_completion_stream_text_chunk("Done");
2299 fake_model.end_last_completion_stream();
2300 cx.run_until_parked();
2301 events.collect::<Vec<_>>().await;
2302 thread.read_with(cx, |thread, _cx| {
2303 assert_eq!(
2304 thread.last_message(),
2305 Some(Message::Agent(AgentMessage {
2306 content: vec![AgentMessageContent::Text("Done".into())],
2307 tool_results: IndexMap::default(),
2308 reasoning_details: None,
2309 }))
2310 );
2311 })
2312}
2313
2314#[gpui::test]
2315async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2316 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2317 let fake_model = model.as_fake();
2318
2319 let mut events = thread
2320 .update(cx, |thread, cx| {
2321 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2322 thread.send(UserMessageId::new(), ["Hello!"], cx)
2323 })
2324 .unwrap();
2325 cx.run_until_parked();
2326
2327 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2328 fake_model.send_last_completion_stream_error(
2329 LanguageModelCompletionError::ServerOverloaded {
2330 provider: LanguageModelProviderName::new("Anthropic"),
2331 retry_after: Some(Duration::from_secs(3)),
2332 },
2333 );
2334 fake_model.end_last_completion_stream();
2335 cx.executor().advance_clock(Duration::from_secs(3));
2336 cx.run_until_parked();
2337 }
2338
2339 let mut errors = Vec::new();
2340 let mut retry_events = Vec::new();
2341 while let Some(event) = events.next().await {
2342 match event {
2343 Ok(ThreadEvent::Retry(retry_status)) => {
2344 retry_events.push(retry_status);
2345 }
2346 Ok(ThreadEvent::Stop(..)) => break,
2347 Err(error) => errors.push(error),
2348 _ => {}
2349 }
2350 }
2351
2352 assert_eq!(
2353 retry_events.len(),
2354 crate::thread::MAX_RETRY_ATTEMPTS as usize
2355 );
2356 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2357 assert_eq!(retry_events[i].attempt, i + 1);
2358 }
2359 assert_eq!(errors.len(), 1);
2360 let error = errors[0]
2361 .downcast_ref::<LanguageModelCompletionError>()
2362 .unwrap();
2363 assert!(matches!(
2364 error,
2365 LanguageModelCompletionError::ServerOverloaded { .. }
2366 ));
2367}
2368
2369/// Filters out the stop events for asserting against in tests
2370fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2371 result_events
2372 .into_iter()
2373 .filter_map(|event| match event.unwrap() {
2374 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2375 _ => None,
2376 })
2377 .collect()
2378}
2379
/// Bundle of handles returned by [`setup`] that tests use to drive and
/// inspect a [`Thread`].
struct ThreadTest {
    /// The model wired into the thread (a `FakeLanguageModel` when created
    /// with `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context the thread was constructed with.
    project_context: Entity<ProjectContext>,
    /// Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing both the settings file and the test project.
    fs: Arc<FakeFs>,
}
2387
/// Which language model a test runs against: the real Sonnet 4 model
/// (requires network access and provider authentication) or an in-process
/// fake whose streams the test drives by hand.
enum TestModel {
    Sonnet4,
    Fake,
}
2392
2393impl TestModel {
2394 fn id(&self) -> LanguageModelId {
2395 match self {
2396 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2397 TestModel::Fake => unreachable!(),
2398 }
2399 }
2400}
2401
/// Builds the common test fixture: a fake filesystem with a settings file
/// selecting a profile that enables all test tools, a test project, the
/// requested model (fake or real Sonnet 4), and a fresh [`Thread`] wired to
/// all of them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file whose default profile enables every test tool.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real-model runs need an HTTP client, a production client,
                // and the language model provider stack initialized.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with later edits to the
        // settings file made by individual tests.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is ready immediately; a real model is
    // looked up in the registry and its provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2503
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only initialize env_logger when the caller opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2511
2512fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2513 let fs = fs.clone();
2514 cx.spawn({
2515 async move |cx| {
2516 let mut new_settings_content_rx = settings::watch_config_file(
2517 cx.background_executor(),
2518 fs,
2519 paths::settings_file().clone(),
2520 );
2521
2522 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2523 cx.update(|cx| {
2524 SettingsStore::update_global(cx, |settings, cx| {
2525 settings.set_user_settings(&new_settings_content, cx)
2526 })
2527 })
2528 .ok();
2529 }
2530 }
2531 })
2532 .detach();
2533}
2534
2535fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2536 completion
2537 .tools
2538 .iter()
2539 .map(|tool| tool.name.clone())
2540 .collect()
2541}
2542
/// Registers a fake stdio MCP context server named `name` that advertises
/// `tools`, and returns a channel of incoming `CallTool` requests, each
/// paired with a sender the test uses to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it.
    // The command is never executed: the transport below is faked.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support (with list_changed notifications).
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Answer ListTools with the caller-provided tool definitions.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each CallTool request to the test over the channel and
        // block the RPC until the test sends back a response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server startup tasks run to completion before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}