1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use fs::{FakeFs, Fs};
10use futures::{
11 StreamExt,
12 channel::{
13 mpsc::{self, UnboundedReceiver},
14 oneshot,
15 },
16};
17use gpui::{
18 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
19};
20use indoc::indoc;
21use language_model::{
22 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
23 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
25 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
26};
27use pretty_assertions::assert_eq;
28use project::{
29 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
30};
31use prompt_store::ProjectContext;
32use reqwest_client::ReqwestClient;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36use settings::{Settings, SettingsStore};
37use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
38use util::path;
39
40mod test_tools;
41use test_tools::*;
42
43#[gpui::test]
44async fn test_echo(cx: &mut TestAppContext) {
45 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
46 let fake_model = model.as_fake();
47
48 let events = thread
49 .update(cx, |thread, cx| {
50 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
51 })
52 .unwrap();
53 cx.run_until_parked();
54 fake_model.send_last_completion_stream_text_chunk("Hello");
55 fake_model
56 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
57 fake_model.end_last_completion_stream();
58
59 let events = events.collect().await;
60 thread.update(cx, |thread, _cx| {
61 assert_eq!(
62 thread.last_message().unwrap().to_markdown(),
63 indoc! {"
64 ## Assistant
65
66 Hello
67 "}
68 )
69 });
70 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
71}
72
73#[gpui::test]
74async fn test_thinking(cx: &mut TestAppContext) {
75 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
76 let fake_model = model.as_fake();
77
78 let events = thread
79 .update(cx, |thread, cx| {
80 thread.send(
81 UserMessageId::new(),
82 [indoc! {"
83 Testing:
84
85 Generate a thinking step where you just think the word 'Think',
86 and have your final answer be 'Hello'
87 "}],
88 cx,
89 )
90 })
91 .unwrap();
92 cx.run_until_parked();
93 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
94 text: "Think".to_string(),
95 signature: None,
96 });
97 fake_model.send_last_completion_stream_text_chunk("Hello");
98 fake_model
99 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
100 fake_model.end_last_completion_stream();
101
102 let events = events.collect().await;
103 thread.update(cx, |thread, _cx| {
104 assert_eq!(
105 thread.last_message().unwrap().to_markdown(),
106 indoc! {"
107 ## Assistant
108
109 <think>Think</think>
110 Hello
111 "}
112 )
113 });
114 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
115}
116
117#[gpui::test]
118async fn test_system_prompt(cx: &mut TestAppContext) {
119 let ThreadTest {
120 model,
121 thread,
122 project_context,
123 ..
124 } = setup(cx, TestModel::Fake).await;
125 let fake_model = model.as_fake();
126
127 project_context.update(cx, |project_context, _cx| {
128 project_context.shell = "test-shell".into()
129 });
130 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
131 thread
132 .update(cx, |thread, cx| {
133 thread.send(UserMessageId::new(), ["abc"], cx)
134 })
135 .unwrap();
136 cx.run_until_parked();
137 let mut pending_completions = fake_model.pending_completions();
138 assert_eq!(
139 pending_completions.len(),
140 1,
141 "unexpected pending completions: {:?}",
142 pending_completions
143 );
144
145 let pending_completion = pending_completions.pop().unwrap();
146 assert_eq!(pending_completion.messages[0].role, Role::System);
147
148 let system_message = &pending_completion.messages[0];
149 let system_prompt = system_message.content[0].to_str().unwrap();
150 assert!(
151 system_prompt.contains("test-shell"),
152 "unexpected system message: {:?}",
153 system_message
154 );
155 assert!(
156 system_prompt.contains("## Fixing Diagnostics"),
157 "unexpected system message: {:?}",
158 system_message
159 );
160}
161
// Verifies prompt-caching behavior: only the most recent message in each
// request should be sent with `cache: true`, while all earlier turns
// (including tool uses and results) are sent with `cache: false`.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (sent after the tool runs) should mark only the
    // trailing tool-result message as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
295
// End-to-end (real model) test of basic tool calling, covering tool calls
// that finish both before and after the model stops streaming.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should have echoed the delay tool's output, which this test
    // expects to contain "Ding".
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
355
// End-to-end (real model) test that tool-call input streams incrementally:
// while a call is still Pending we should observe a partially-streamed
// input, and once it completes all expected keys must be present.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a request that should make the model call the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be streaming in; a
                        // missing "g" key is treated as evidence of a
                        // partially-streamed input.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
407
408#[gpui::test]
409async fn test_tool_authorization(cx: &mut TestAppContext) {
410 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
411 let fake_model = model.as_fake();
412
413 let mut events = thread
414 .update(cx, |thread, cx| {
415 thread.add_tool(ToolRequiringPermission);
416 thread.send(UserMessageId::new(), ["abc"], cx)
417 })
418 .unwrap();
419 cx.run_until_parked();
420 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
421 LanguageModelToolUse {
422 id: "tool_id_1".into(),
423 name: ToolRequiringPermission::name().into(),
424 raw_input: "{}".into(),
425 input: json!({}),
426 is_input_complete: true,
427 },
428 ));
429 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
430 LanguageModelToolUse {
431 id: "tool_id_2".into(),
432 name: ToolRequiringPermission::name().into(),
433 raw_input: "{}".into(),
434 input: json!({}),
435 is_input_complete: true,
436 },
437 ));
438 fake_model.end_last_completion_stream();
439 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
440 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
441
442 // Approve the first
443 tool_call_auth_1
444 .response
445 .send(tool_call_auth_1.options[1].id.clone())
446 .unwrap();
447 cx.run_until_parked();
448
449 // Reject the second
450 tool_call_auth_2
451 .response
452 .send(tool_call_auth_1.options[2].id.clone())
453 .unwrap();
454 cx.run_until_parked();
455
456 let completion = fake_model.pending_completions().pop().unwrap();
457 let message = completion.messages.last().unwrap();
458 assert_eq!(
459 message.content,
460 vec![
461 language_model::MessageContent::ToolResult(LanguageModelToolResult {
462 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
463 tool_name: ToolRequiringPermission::name().into(),
464 is_error: false,
465 content: "Allowed".into(),
466 output: Some("Allowed".into())
467 }),
468 language_model::MessageContent::ToolResult(LanguageModelToolResult {
469 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
470 tool_name: ToolRequiringPermission::name().into(),
471 is_error: true,
472 content: "Permission to run tool denied by user".into(),
473 output: None
474 })
475 ]
476 );
477
478 // Simulate yet another tool call.
479 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
480 LanguageModelToolUse {
481 id: "tool_id_3".into(),
482 name: ToolRequiringPermission::name().into(),
483 raw_input: "{}".into(),
484 input: json!({}),
485 is_input_complete: true,
486 },
487 ));
488 fake_model.end_last_completion_stream();
489
490 // Respond by always allowing tools.
491 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
492 tool_call_auth_3
493 .response
494 .send(tool_call_auth_3.options[0].id.clone())
495 .unwrap();
496 cx.run_until_parked();
497 let completion = fake_model.pending_completions().pop().unwrap();
498 let message = completion.messages.last().unwrap();
499 assert_eq!(
500 message.content,
501 vec![language_model::MessageContent::ToolResult(
502 LanguageModelToolResult {
503 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
504 tool_name: ToolRequiringPermission::name().into(),
505 is_error: false,
506 content: "Allowed".into(),
507 output: Some("Allowed".into())
508 }
509 )]
510 );
511
512 // Simulate a final tool call, ensuring we don't trigger authorization.
513 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
514 LanguageModelToolUse {
515 id: "tool_id_4".into(),
516 name: ToolRequiringPermission::name().into(),
517 raw_input: "{}".into(),
518 input: json!({}),
519 is_input_complete: true,
520 },
521 ));
522 fake_model.end_last_completion_stream();
523 cx.run_until_parked();
524 let completion = fake_model.pending_completions().pop().unwrap();
525 let message = completion.messages.last().unwrap();
526 assert_eq!(
527 message.content,
528 vec![language_model::MessageContent::ToolResult(
529 LanguageModelToolResult {
530 tool_use_id: "tool_id_4".into(),
531 tool_name: ToolRequiringPermission::name().into(),
532 is_error: false,
533 content: "Allowed".into(),
534 output: Some("Allowed".into())
535 }
536 )]
537 );
538}
539
540#[gpui::test]
541async fn test_tool_hallucination(cx: &mut TestAppContext) {
542 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
543 let fake_model = model.as_fake();
544
545 let mut events = thread
546 .update(cx, |thread, cx| {
547 thread.send(UserMessageId::new(), ["abc"], cx)
548 })
549 .unwrap();
550 cx.run_until_parked();
551 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
552 LanguageModelToolUse {
553 id: "tool_id_1".into(),
554 name: "nonexistent_tool".into(),
555 raw_input: "{}".into(),
556 input: json!({}),
557 is_input_complete: true,
558 },
559 ));
560 fake_model.end_last_completion_stream();
561
562 let tool_call = expect_tool_call(&mut events).await;
563 assert_eq!(tool_call.title, "nonexistent_tool");
564 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
565 let update = expect_tool_call_update_fields(&mut events).await;
566 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
567}
568
// After the provider reports ToolUseLimitReached, `resume` should re-send the
// conversation with an extra "Continue where you left off" user message.
// Calling `resume` when the limit was *not* reached must return an error.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran, so the follow-up request ends with its (cached) result.
    // messages[0] is the system prompt; comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
686
// After hitting the tool-use limit, sending a brand-new user message should
// still work: the preserved tool use/result pair stays in the history and the
// new message is appended (and cached) after it.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use and immediately report the tool-use limit.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh send should succeed despite the earlier limit error.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
760
761async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
762 let event = events
763 .next()
764 .await
765 .expect("no tool call authorization event received")
766 .unwrap();
767 match event {
768 ThreadEvent::ToolCall(tool_call) => tool_call,
769 event => {
770 panic!("Unexpected event {event:?}");
771 }
772 }
773}
774
775async fn expect_tool_call_update_fields(
776 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
777) -> acp::ToolCallUpdate {
778 let event = events
779 .next()
780 .await
781 .expect("no tool call authorization event received")
782 .unwrap();
783 match event {
784 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
785 event => {
786 panic!("Unexpected event {event:?}");
787 }
788 }
789}
790
791async fn next_tool_call_authorization(
792 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
793) -> ToolCallAuthorization {
794 loop {
795 let event = events
796 .next()
797 .await
798 .expect("no tool call authorization event received")
799 .unwrap();
800 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
801 let permission_kinds = tool_call_authorization
802 .options
803 .iter()
804 .map(|o| o.kind)
805 .collect::<Vec<_>>();
806 assert_eq!(
807 permission_kinds,
808 vec![
809 acp::PermissionOptionKind::AllowAlways,
810 acp::PermissionOptionKind::AllowOnce,
811 acp::PermissionOptionKind::RejectOnce,
812 ]
813 );
814 return tool_call_authorization;
815 }
816 }
817}
818
819#[gpui::test]
820#[cfg_attr(not(feature = "e2e"), ignore)]
821async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
822 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
823
824 // Test concurrent tool calls with different delay times
825 let events = thread
826 .update(cx, |thread, cx| {
827 thread.add_tool(DelayTool);
828 thread.send(
829 UserMessageId::new(),
830 [
831 "Call the delay tool twice in the same message.",
832 "Once with 100ms. Once with 300ms.",
833 "When both timers are complete, describe the outputs.",
834 ],
835 cx,
836 )
837 })
838 .unwrap()
839 .collect()
840 .await;
841
842 let stop_reasons = stop_events(events);
843 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
844
845 thread.update(cx, |thread, _cx| {
846 let last_message = thread.last_message().unwrap();
847 let agent_message = last_message.as_agent_message().unwrap();
848 let text = agent_message
849 .content
850 .iter()
851 .filter_map(|content| {
852 if let AgentMessageContent::Text(text) = content {
853 Some(text.as_str())
854 } else {
855 None
856 }
857 })
858 .collect::<String>();
859
860 assert!(text.contains("Ding"));
861 });
862}
863
// Verifies that switching agent profiles changes which registered tools are
// offered to the model in subsequent completion requests.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools up front; profiles decide which are exposed.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
943
// Verifies that tools provided by an MCP context server are exposed to the
// model, and that a name collision with a native tool is resolved by
// prefixing the MCP tool with its server name ("test_server_echo").
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing a single "echo" tool whose schema
    // mirrors the native EchoTool's.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should receive the call; answer it with "test".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this asserts on the `completion` popped *before* the tool
    // result was sent — it re-checks stale data. Presumably the intent was to
    // pop the new pending completion here; verify and update if so.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Call both the (renamed) MCP tool and the native tool in one turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server still sees its tool under the original name "echo".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1105
// Verifies how MCP (context server) tool names are reconciled against native
// tools and against each other: colliding names get disambiguated with a
// server-derived prefix, and names at or over MAX_TOOL_NAME_LENGTH are
// shortened so the final set fits the length budget (see the expected-name
// assertion at the end).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names just under the limit; these also collide with
            // server "zzz" below, forcing prefixed disambiguation.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One name strictly over the limit; must be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request must contain the final, de-duplicated,
    // length-limited tool names (sorted), e.g. "xxx_echo"/"yyy_echo" for the
    // colliding "echo" tools and "y_…"/"z_…" prefixes for the colliding long
    // names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1271
// E2E test (real model; ignored without the "e2e" feature): cancelling a turn
// while a tool is still running must close the event stream with a Cancelled
// stop reason, and the thread must remain usable afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call specifically, so we cancel
            // while only the infinite tool is still running.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1356
1357#[gpui::test]
1358async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1359 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1360 let fake_model = model.as_fake();
1361
1362 let events_1 = thread
1363 .update(cx, |thread, cx| {
1364 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1365 })
1366 .unwrap();
1367 cx.run_until_parked();
1368 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1369 cx.run_until_parked();
1370
1371 let events_2 = thread
1372 .update(cx, |thread, cx| {
1373 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1374 })
1375 .unwrap();
1376 cx.run_until_parked();
1377 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1378 fake_model
1379 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1380 fake_model.end_last_completion_stream();
1381
1382 let events_1 = events_1.collect::<Vec<_>>().await;
1383 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1384 let events_2 = events_2.collect::<Vec<_>>().await;
1385 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1386}
1387
1388#[gpui::test]
1389async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1390 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1391 let fake_model = model.as_fake();
1392
1393 let events_1 = thread
1394 .update(cx, |thread, cx| {
1395 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1396 })
1397 .unwrap();
1398 cx.run_until_parked();
1399 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1400 fake_model
1401 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1402 fake_model.end_last_completion_stream();
1403 let events_1 = events_1.collect::<Vec<_>>().await;
1404
1405 let events_2 = thread
1406 .update(cx, |thread, cx| {
1407 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1408 })
1409 .unwrap();
1410 cx.run_until_parked();
1411 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1412 fake_model
1413 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1414 fake_model.end_last_completion_stream();
1415 let events_2 = events_2.collect::<Vec<_>>().await;
1416
1417 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1418 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1419}
1420
// When the model refuses (StopReason::Refusal), the thread must roll back all
// messages after (and including) the last user message, leaving the thread
// empty in this single-exchange case.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded immediately, before any model output.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Partial assistant output is appended as it streams in.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1469
// Truncating at the first (and only) user message must clear the whole
// thread — including the reported token usage — and the thread must remain
// usable for a fresh exchange afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // No usage has been reported yet, only the user message exists.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // Usage is input + output tokens against the model's context window.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message empties the thread and resets usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1585
// Truncating at the second user message must restore the thread to the exact
// state it was in after the first exchange, including the first exchange's
// token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: "Message 1" -> "Message 1 response" with usage 32k/16k.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused twice: once after the first exchange, once
    // after truncating the second — both must observe the identical state.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose message id we keep in order to truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at message 2; the thread must roll back to the first-exchange
    // state, including the earlier 32k+16k usage figure.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1694
// A separate summarization model generates the thread title after the first
// exchange; only the first line of its output is used, and no title is
// regenerated on subsequent messages.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // The title is produced by a dedicated summarization model, distinct
    // from the main completion model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title is kept.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Everything from the first newline on ("Goodnight Moon") is discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1740
// When a completion request is rebuilt mid-turn, tool uses that are still
// awaiting permission must be omitted, while completed tool uses (and their
// results) are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool stays pending (its permission prompt is never answered).
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool runs to completion without needing permission.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    // messages[0] is the system prompt, hence the [1..] slice.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1815
// Integration test of the ACP connection surface (NativeAgentConnection):
// model listing/selection, thread creation, prompting, cancellation, and
// session cleanup when the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1948
// Verifies the exact sequence of ACP tool-call events emitted while a tool's
// input streams in and the tool then runs: Pending -> input update ->
// InProgress -> content -> Completed with raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) Initial tool call created in Pending state with the partial input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2) Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3) Status flips to InProgress when the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4) Tool content streams through as an update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5) Final update: Completed status plus the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2050
2051#[gpui::test]
2052async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2053 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2054 let fake_model = model.as_fake();
2055
2056 let mut events = thread
2057 .update(cx, |thread, cx| {
2058 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2059 thread.send(UserMessageId::new(), ["Hello!"], cx)
2060 })
2061 .unwrap();
2062 cx.run_until_parked();
2063
2064 fake_model.send_last_completion_stream_text_chunk("Hey!");
2065 fake_model.end_last_completion_stream();
2066
2067 let mut retry_events = Vec::new();
2068 while let Some(Ok(event)) = events.next().await {
2069 match event {
2070 ThreadEvent::Retry(retry_status) => {
2071 retry_events.push(retry_status);
2072 }
2073 ThreadEvent::Stop(..) => break,
2074 _ => {}
2075 }
2076 }
2077
2078 assert_eq!(retry_events.len(), 0);
2079 thread.read_with(cx, |thread, _cx| {
2080 assert_eq!(
2081 thread.to_markdown(),
2082 indoc! {"
2083 ## User
2084
2085 Hello!
2086
2087 ## Assistant
2088
2089 Hey!
2090 "}
2091 )
2092 });
2093}
2094
// A ServerOverloaded error must trigger exactly one retry (after the
// provider-supplied retry_after delay), and the retried completion's output
// must land in the thread as if nothing had failed.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable overload error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past retry_after so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2151
2152#[gpui::test]
2153async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2154 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2155 let fake_model = model.as_fake();
2156
2157 let mut events = thread
2158 .update(cx, |thread, cx| {
2159 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2160 thread.send(UserMessageId::new(), ["Hello!"], cx)
2161 })
2162 .unwrap();
2163 cx.run_until_parked();
2164
2165 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2166 fake_model.send_last_completion_stream_error(
2167 LanguageModelCompletionError::ServerOverloaded {
2168 provider: LanguageModelProviderName::new("Anthropic"),
2169 retry_after: Some(Duration::from_secs(3)),
2170 },
2171 );
2172 fake_model.end_last_completion_stream();
2173 cx.executor().advance_clock(Duration::from_secs(3));
2174 cx.run_until_parked();
2175 }
2176
2177 let mut errors = Vec::new();
2178 let mut retry_events = Vec::new();
2179 while let Some(event) = events.next().await {
2180 match event {
2181 Ok(ThreadEvent::Retry(retry_status)) => {
2182 retry_events.push(retry_status);
2183 }
2184 Ok(ThreadEvent::Stop(..)) => break,
2185 Err(error) => errors.push(error),
2186 _ => {}
2187 }
2188 }
2189
2190 assert_eq!(
2191 retry_events.len(),
2192 crate::thread::MAX_RETRY_ATTEMPTS as usize
2193 );
2194 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2195 assert_eq!(retry_events[i].attempt, i + 1);
2196 }
2197 assert_eq!(errors.len(), 1);
2198 let error = errors[0]
2199 .downcast_ref::<LanguageModelCompletionError>()
2200 .unwrap();
2201 assert!(matches!(
2202 error,
2203 LanguageModelCompletionError::ServerOverloaded { .. }
2204 ));
2205}
2206
2207/// Filters out the stop events for asserting against in tests
2208fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2209 result_events
2210 .into_iter()
2211 .filter_map(|event| match event.unwrap() {
2212 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2213 _ => None,
2214 })
2215 .collect()
2216}
2217
/// Bundle of handles returned by [`setup`] that tests use to drive a thread.
struct ThreadTest {
    // Model the thread is wired to: a `FakeLanguageModel` or a registry model,
    // depending on the `TestModel` passed to `setup`.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to `Thread::new`; tests can mutate it directly.
    project_context: Entity<ProjectContext>,
    // Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
2225
/// Which language model a test runs against.
enum TestModel {
    // A real model resolved from the registry by id (requires authentication).
    Sonnet4,
    // An in-process `FakeLanguageModel` driven manually by the test.
    Fake,
}
2230
2231impl TestModel {
2232 fn id(&self) -> LanguageModelId {
2233 match self {
2234 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2235 TestModel::Fake => unreachable!(),
2236 }
2237 }
2238}
2239
/// Builds the shared test fixture: a fake filesystem seeded with a settings
/// file, global app/client/model initialization, a test project rooted at
/// `/test`, and a `Thread` connected to the requested `model`.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose "test-profile" enables every test tool, so
    // threads created below have them available by default.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize global state (settings, HTTP client, language-model
    // registry) that `Thread` and the model providers depend on. Order
    // matters here: settings first, then client, then model registration.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits tests make to the
        // fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Either construct a fake model directly, or resolve the model by id from
    // the registry and authenticate its provider before returning it.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2339
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only install env_logger when RUST_LOG is explicitly set, so test output
    // stays quiet by default.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
2347
2348fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2349 let fs = fs.clone();
2350 cx.spawn({
2351 async move |cx| {
2352 let mut new_settings_content_rx = settings::watch_config_file(
2353 cx.background_executor(),
2354 fs,
2355 paths::settings_file().clone(),
2356 );
2357
2358 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2359 cx.update(|cx| {
2360 SettingsStore::update_global(cx, |settings, cx| {
2361 settings.set_user_settings(&new_settings_content, cx)
2362 })
2363 })
2364 .ok();
2365 }
2366 }
2367 })
2368 .detach();
2369}
2370
2371fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2372 completion
2373 .tools
2374 .iter()
2375 .map(|tool| tool.name.clone())
2376 .collect()
2377}
2378
/// Registers and starts a fake MCP context server named `name` exposing
/// `tools` in `context_server_store`.
///
/// Returns a channel that yields one `(CallToolParams, responder)` pair per
/// tool call the server receives; the test answers a call by sending a
/// `CallToolResponse` on the paired oneshot sender.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it as an
    // enabled custom server.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    // Never actually executed; the fake transport below
                    // handles all protocol traffic.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport speaking just enough of the protocol: initialize,
    // tools/list (serving `tools`), and tools/call (forwarded to the
    // returned channel so the test controls each response).
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test replies on the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the store finish the initialize/list handshake before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}