1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies that the system prompt is built from the project context
    // (the configured shell shows up verbatim) and that registering a tool
    // pulls tool-related guidance into the prompt.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // With at least one tool registered, the "## Fixing Diagnostics"
    // section is expected to be present (contrast with
    // `test_system_prompt_without_tools`).
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // messages[0] is always the system message for a fresh thread.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the cache-marking strategy for completion requests: in every
    // request, only the newest message carries `cache: true` so the provider
    // can anchor its prompt cache at the longest shared prefix.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, hence the `[1..]` slice throughout.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request's final message is the tool
    // result, and that message becomes the new cache anchor.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // E2E: verifies tool calls resolve correctly whether the tool finishes
    // before or after the model stops streaming.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should echo the delay tool's output ("Ding") in its final
    // assistant message.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
393
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // E2E: verifies that tool-call input streams incrementally. While the
    // call is still pending we should observe a tool use whose input is
    // incomplete; once the call resolves, all fields must have streamed in.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a turn that should invoke the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Input still streaming: count it as partial only if
                        // the last field ("g") hasn't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
445
446#[gpui::test]
447async fn test_tool_authorization(cx: &mut TestAppContext) {
448 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
449 let fake_model = model.as_fake();
450
451 let mut events = thread
452 .update(cx, |thread, cx| {
453 thread.add_tool(ToolRequiringPermission);
454 thread.send(UserMessageId::new(), ["abc"], cx)
455 })
456 .unwrap();
457 cx.run_until_parked();
458 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
459 LanguageModelToolUse {
460 id: "tool_id_1".into(),
461 name: ToolRequiringPermission::name().into(),
462 raw_input: "{}".into(),
463 input: json!({}),
464 is_input_complete: true,
465 thought_signature: None,
466 },
467 ));
468 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
469 LanguageModelToolUse {
470 id: "tool_id_2".into(),
471 name: ToolRequiringPermission::name().into(),
472 raw_input: "{}".into(),
473 input: json!({}),
474 is_input_complete: true,
475 thought_signature: None,
476 },
477 ));
478 fake_model.end_last_completion_stream();
479 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
480 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
481
482 // Approve the first
483 tool_call_auth_1
484 .response
485 .send(tool_call_auth_1.options[1].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488
489 // Reject the second
490 tool_call_auth_2
491 .response
492 .send(tool_call_auth_1.options[2].id.clone())
493 .unwrap();
494 cx.run_until_parked();
495
496 let completion = fake_model.pending_completions().pop().unwrap();
497 let message = completion.messages.last().unwrap();
498 assert_eq!(
499 message.content,
500 vec![
501 language_model::MessageContent::ToolResult(LanguageModelToolResult {
502 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
503 tool_name: ToolRequiringPermission::name().into(),
504 is_error: false,
505 content: "Allowed".into(),
506 output: Some("Allowed".into())
507 }),
508 language_model::MessageContent::ToolResult(LanguageModelToolResult {
509 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
510 tool_name: ToolRequiringPermission::name().into(),
511 is_error: true,
512 content: "Permission to run tool denied by user".into(),
513 output: Some("Permission to run tool denied by user".into())
514 })
515 ]
516 );
517
518 // Simulate yet another tool call.
519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
520 LanguageModelToolUse {
521 id: "tool_id_3".into(),
522 name: ToolRequiringPermission::name().into(),
523 raw_input: "{}".into(),
524 input: json!({}),
525 is_input_complete: true,
526 thought_signature: None,
527 },
528 ));
529 fake_model.end_last_completion_stream();
530
531 // Respond by always allowing tools.
532 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
533 tool_call_auth_3
534 .response
535 .send(tool_call_auth_3.options[0].id.clone())
536 .unwrap();
537 cx.run_until_parked();
538 let completion = fake_model.pending_completions().pop().unwrap();
539 let message = completion.messages.last().unwrap();
540 assert_eq!(
541 message.content,
542 vec![language_model::MessageContent::ToolResult(
543 LanguageModelToolResult {
544 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
545 tool_name: ToolRequiringPermission::name().into(),
546 is_error: false,
547 content: "Allowed".into(),
548 output: Some("Allowed".into())
549 }
550 )]
551 );
552
553 // Simulate a final tool call, ensuring we don't trigger authorization.
554 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
555 LanguageModelToolUse {
556 id: "tool_id_4".into(),
557 name: ToolRequiringPermission::name().into(),
558 raw_input: "{}".into(),
559 input: json!({}),
560 is_input_complete: true,
561 thought_signature: None,
562 },
563 ));
564 fake_model.end_last_completion_stream();
565 cx.run_until_parked();
566 let completion = fake_model.pending_completions().pop().unwrap();
567 let message = completion.messages.last().unwrap();
568 assert_eq!(
569 message.content,
570 vec![language_model::MessageContent::ToolResult(
571 LanguageModelToolResult {
572 tool_use_id: "tool_id_4".into(),
573 tool_name: ToolRequiringPermission::name().into(),
574 is_error: false,
575 content: "Allowed".into(),
576 output: Some("Allowed".into())
577 }
578 )]
579 );
580}
581
582#[gpui::test]
583async fn test_tool_hallucination(cx: &mut TestAppContext) {
584 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
585 let fake_model = model.as_fake();
586
587 let mut events = thread
588 .update(cx, |thread, cx| {
589 thread.send(UserMessageId::new(), ["abc"], cx)
590 })
591 .unwrap();
592 cx.run_until_parked();
593 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
594 LanguageModelToolUse {
595 id: "tool_id_1".into(),
596 name: "nonexistent_tool".into(),
597 raw_input: "{}".into(),
598 input: json!({}),
599 is_input_complete: true,
600 thought_signature: None,
601 },
602 ));
603 fake_model.end_last_completion_stream();
604
605 let tool_call = expect_tool_call(&mut events).await;
606 assert_eq!(tool_call.title, "nonexistent_tool");
607 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
608 let update = expect_tool_call_update_fields(&mut events).await;
609 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
610}
611
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // When the provider reports that the tool-use limit was reached, the turn
    // ends with `ToolUseLimitReachedError`; calling `resume` should retry by
    // appending a "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran, so the follow-up request must contain the tool result;
    // the newest message is the cache anchor (`cache: true`).
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the limit error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
721
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (instead of resuming) should still work: the prior tool use/result
    // pair stays in the conversation and the new message is appended.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately report the limit in the same stream.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh user message starts a new request containing the full history.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
796
797async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
798 let event = events
799 .next()
800 .await
801 .expect("no tool call authorization event received")
802 .unwrap();
803 match event {
804 ThreadEvent::ToolCall(tool_call) => tool_call,
805 event => {
806 panic!("Unexpected event {event:?}");
807 }
808 }
809}
810
811async fn expect_tool_call_update_fields(
812 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
813) -> acp::ToolCallUpdate {
814 let event = events
815 .next()
816 .await
817 .expect("no tool call authorization event received")
818 .unwrap();
819 match event {
820 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
821 event => {
822 panic!("Unexpected event {event:?}");
823 }
824 }
825}
826
827async fn next_tool_call_authorization(
828 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
829) -> ToolCallAuthorization {
830 loop {
831 let event = events
832 .next()
833 .await
834 .expect("no tool call authorization event received")
835 .unwrap();
836 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
837 let permission_kinds = tool_call_authorization
838 .options
839 .iter()
840 .map(|o| o.kind)
841 .collect::<Vec<_>>();
842 assert_eq!(
843 permission_kinds,
844 vec![
845 acp::PermissionOptionKind::AllowAlways,
846 acp::PermissionOptionKind::AllowOnce,
847 acp::PermissionOptionKind::RejectOnce,
848 ]
849 );
850 return tool_call_authorization;
851 }
852 }
853}
854
855#[gpui::test]
856#[cfg_attr(not(feature = "e2e"), ignore)]
857async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
858 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
859
860 // Test concurrent tool calls with different delay times
861 let events = thread
862 .update(cx, |thread, cx| {
863 thread.add_tool(DelayTool);
864 thread.send(
865 UserMessageId::new(),
866 [
867 "Call the delay tool twice in the same message.",
868 "Once with 100ms. Once with 300ms.",
869 "When both timers are complete, describe the outputs.",
870 ],
871 cx,
872 )
873 })
874 .unwrap()
875 .collect()
876 .await;
877
878 let stop_reasons = stop_events(events);
879 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
880
881 thread.update(cx, |thread, _cx| {
882 let last_message = thread.last_message().unwrap();
883 let agent_message = last_message.as_agent_message().unwrap();
884 let text = agent_message
885 .content
886 .iter()
887 .filter_map(|content| {
888 if let AgentMessageContent::Text(text) = content {
889 Some(text.as_str())
890 } else {
891 None
892 }
893 })
894 .collect::<String>();
895
896 assert!(text.contains("Ding"));
897 });
898}
899
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which registered tools
    // are offered to the model in a completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; profiles decide which subset is exposed.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
979
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies MCP (context server) tool integration: an MCP tool is offered
    // to the model, invoked through the context server, and — when a native
    // tool with the same name is later added — the MCP tool is renamed with a
    // server prefix ("test_server_echo") to resolve the collision.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing a tool named "echo" with the same
    // schema as the native EchoTool, to set up the later name collision.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call must be forwarded to the context server with its arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` popped before the
    // tool call above, not on a freshly popped follow-up request —
    // presumably the follow-up was intended; confirm against the harness.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Invoke both the renamed MCP tool and the native tool in one turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP side still sees the tool under its original, unprefixed name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1145
// Verifies how MCP (context-server) tool names are combined with native tool
// names in a completion request: names that collide with a native tool are
// disambiguated with a server-derived prefix, and names at or over
// MAX_TOOL_NAME_LENGTH are shortened so every entry fits within the limit.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy" adds another "echo" collision plus names sitting right at
    // the length limit (MAX_TOOL_NAME_LENGTH - 2 and - 1).
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Server "zzz" duplicates the near-limit names from "yyy" (forcing
    // cross-server disambiguation) and adds one name over the limit.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected (sorted) names: native tools keep their names; MCP "echo"
    // collisions get server-id prefixes ("xxx_echo", "yyy_echo"); near-limit
    // names that collide across servers get a single-character prefix
    // ("y_…", "z_…"); the over-limit "c…" name is truncated to fit.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1311
// End-to-end test (only runs with the "e2e" feature): cancelling a turn while
// a tool is still executing must close the event stream with a `Cancelled`
// stop reason, and the thread must remain usable for a subsequent send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order we prompted for.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track when the echo tool (the one that terminates) completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have started and echo has finished; the
        // infinite tool is deliberately still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1397
1398#[gpui::test]
1399async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1400 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1401 let fake_model = model.as_fake();
1402
1403 let events_1 = thread
1404 .update(cx, |thread, cx| {
1405 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1406 })
1407 .unwrap();
1408 cx.run_until_parked();
1409 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1410 cx.run_until_parked();
1411
1412 let events_2 = thread
1413 .update(cx, |thread, cx| {
1414 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1415 })
1416 .unwrap();
1417 cx.run_until_parked();
1418 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1419 fake_model
1420 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1421 fake_model.end_last_completion_stream();
1422
1423 let events_1 = events_1.collect::<Vec<_>>().await;
1424 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1425 let events_2 = events_2.collect::<Vec<_>>().await;
1426 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1427}
1428
1429#[gpui::test]
1430async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1431 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1432 let fake_model = model.as_fake();
1433
1434 let events_1 = thread
1435 .update(cx, |thread, cx| {
1436 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1437 })
1438 .unwrap();
1439 cx.run_until_parked();
1440 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1441 fake_model
1442 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1443 fake_model.end_last_completion_stream();
1444 let events_1 = events_1.collect::<Vec<_>>().await;
1445
1446 let events_2 = thread
1447 .update(cx, |thread, cx| {
1448 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1449 })
1450 .unwrap();
1451 cx.run_until_parked();
1452 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1453 fake_model
1454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1455 fake_model.end_last_completion_stream();
1456 let events_2 = events_2.collect::<Vec<_>>().await;
1457
1458 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1459 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1460}
1461
1462#[gpui::test]
1463async fn test_refusal(cx: &mut TestAppContext) {
1464 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1465 let fake_model = model.as_fake();
1466
1467 let events = thread
1468 .update(cx, |thread, cx| {
1469 thread.send(UserMessageId::new(), ["Hello"], cx)
1470 })
1471 .unwrap();
1472 cx.run_until_parked();
1473 thread.read_with(cx, |thread, _| {
1474 assert_eq!(
1475 thread.to_markdown(),
1476 indoc! {"
1477 ## User
1478
1479 Hello
1480 "}
1481 );
1482 });
1483
1484 fake_model.send_last_completion_stream_text_chunk("Hey!");
1485 cx.run_until_parked();
1486 thread.read_with(cx, |thread, _| {
1487 assert_eq!(
1488 thread.to_markdown(),
1489 indoc! {"
1490 ## User
1491
1492 Hello
1493
1494 ## Assistant
1495
1496 Hey!
1497 "}
1498 );
1499 });
1500
1501 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1502 fake_model
1503 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1504 let events = events.collect::<Vec<_>>().await;
1505 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1506 thread.read_with(cx, |thread, _| {
1507 assert_eq!(thread.to_markdown(), "");
1508 });
1509}
1510
// Truncating at the first (and only) user message should clear the thread
// entirely — both the markdown transcript and the tracked token usage — and
// the thread must accept and complete new messages afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has finished yet, so no usage is recorded.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update so `latest_token_usage` is set.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Used tokens are input + output; max_tokens presumably comes from
        // the fake model's configured context window — TODO confirm.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message resets the thread to empty.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation exchange.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1626
// Truncating at the second user message should restore the thread exactly to
// its state after the first exchange — transcript and token usage included.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Complete a first full exchange, including a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused before and after truncation: the thread
    // should look identical in both cases.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, keeping the message id so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message; the first exchange (and its usage)
    // must be restored exactly.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1735
1736#[gpui::test]
1737async fn test_title_generation(cx: &mut TestAppContext) {
1738 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1739 let fake_model = model.as_fake();
1740
1741 let summary_model = Arc::new(FakeLanguageModel::default());
1742 thread.update(cx, |thread, cx| {
1743 thread.set_summarization_model(Some(summary_model.clone()), cx)
1744 });
1745
1746 let send = thread
1747 .update(cx, |thread, cx| {
1748 thread.send(UserMessageId::new(), ["Hello"], cx)
1749 })
1750 .unwrap();
1751 cx.run_until_parked();
1752
1753 fake_model.send_last_completion_stream_text_chunk("Hey!");
1754 fake_model.end_last_completion_stream();
1755 cx.run_until_parked();
1756 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1757
1758 // Ensure the summary model has been invoked to generate a title.
1759 summary_model.send_last_completion_stream_text_chunk("Hello ");
1760 summary_model.send_last_completion_stream_text_chunk("world\nG");
1761 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1762 summary_model.end_last_completion_stream();
1763 send.collect::<Vec<_>>().await;
1764 cx.run_until_parked();
1765 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1766
1767 // Send another message, ensuring no title is generated this time.
1768 let send = thread
1769 .update(cx, |thread, cx| {
1770 thread.send(UserMessageId::new(), ["Hello again"], cx)
1771 })
1772 .unwrap();
1773 cx.run_until_parked();
1774 fake_model.send_last_completion_stream_text_chunk("Hey again!");
1775 fake_model.end_last_completion_stream();
1776 cx.run_until_parked();
1777 assert_eq!(summary_model.pending_completions(), Vec::new());
1778 send.collect::<Vec<_>>().await;
1779 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1780}
1781
// A tool call that is still waiting on user permission must be excluded when
// building a completion request, while tool calls that already completed are
// included together with their results.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending because permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool call runs to completion and should appear in the request.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped — presumably the system prompt; verify if the
    // surrounding harness changes. Only the echo tool use and its result are
    // present; the permission tool is absent entirely.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1858
// Integration test for `NativeAgentConnection`: builds the full agent stack
// (client, language-model registry, project), then exercises model listing
// and selection, prompting through an `AcpThread`, cancellation, and session
// cleanup once the thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream one assistant chunk back.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1996
// Streams a tool call whose input arrives in two stages (partial, then
// complete) and asserts the exact sequence of events the thread emits: an
// initial `ToolCall`, then field updates for the finalized input, the
// `InProgress` status, the streamed content, and finally `Completed` with
// the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The initial tool call, created from the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2. The completed input replaces the partial one.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3. The tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4. The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5. The tool completes, carrying its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2105
2106#[gpui::test]
2107async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2108 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2109 let fake_model = model.as_fake();
2110
2111 let mut events = thread
2112 .update(cx, |thread, cx| {
2113 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2114 thread.send(UserMessageId::new(), ["Hello!"], cx)
2115 })
2116 .unwrap();
2117 cx.run_until_parked();
2118
2119 fake_model.send_last_completion_stream_text_chunk("Hey!");
2120 fake_model.end_last_completion_stream();
2121
2122 let mut retry_events = Vec::new();
2123 while let Some(Ok(event)) = events.next().await {
2124 match event {
2125 ThreadEvent::Retry(retry_status) => {
2126 retry_events.push(retry_status);
2127 }
2128 ThreadEvent::Stop(..) => break,
2129 _ => {}
2130 }
2131 }
2132
2133 assert_eq!(retry_events.len(), 0);
2134 thread.read_with(cx, |thread, _cx| {
2135 assert_eq!(
2136 thread.to_markdown(),
2137 indoc! {"
2138 ## User
2139
2140 Hello!
2141
2142 ## Assistant
2143
2144 Hey!
2145 "}
2146 )
2147 });
2148}
2149
// Verifies that a retryable provider error (ServerOverloaded) mid-stream makes
// the thread emit a Retry event, wait out `retry_after`, and resume the turn
// with a fresh completion request.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries for this turn.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a reply, then fail with a retryable error carrying a
    // 3-second retry hint.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advancing the clock past `retry_after` should start a new completion.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Finish the turn on the retried completion.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry was reported, as attempt number 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        // The transcript keeps the interrupted reply, a "[resume]" marker, and
        // the post-retry continuation as a separate assistant message.
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2214
// Verifies that when a retryable error interrupts a turn containing a tool
// call, the tool call still completes, and the retried completion request
// includes both the tool use and its result.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries for this turn.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit a complete tool use, then fail the stream with a retryable error
    // before the model produces any final text.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    // The retried request must carry the tool use and the finished tool result
    // (messages[0] is the system prompt, hence the [1..] slice).
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and confirm the final assistant message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2291
// Verifies that retries stop after MAX_RETRY_ATTEMPTS: the thread emits one
// Retry event per attempt (with increasing attempt numbers) and then surfaces
// the provider error to the caller instead of retrying forever.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries for this turn.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every completion: the initial one plus each of the retries.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as an error, not retried again.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2346
2347/// Filters out the stop events for asserting against in tests
2348fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2349 result_events
2350 .into_iter()
2351 .filter_map(|event| match event.unwrap() {
2352 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2353 _ => None,
2354 })
2355 .collect()
2356}
2357
/// Harness returned by `setup`: the model under test plus the entities a test
/// needs to drive and inspect a `Thread`.
struct ThreadTest {
    // Either a `FakeLanguageModel` or a real provider model (see `TestModel`).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread; tests may mutate it directly.
    project_context: Entity<ProjectContext>,
    // Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the project and the watched settings file.
    fs: Arc<FakeFs>,
}
2365
/// Which language model a test should run against.
enum TestModel {
    /// A real Anthropic model resolved from the registry; requires
    /// authentication, so only live/eval-style tests use it.
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose stream the test drives manually.
    Fake,
}
2370
2371impl TestModel {
2372 fn id(&self) -> LanguageModelId {
2373 match self {
2374 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2375 TestModel::Fake => unreachable!(),
2376 }
2377 }
2378}
2379
/// Builds a fully wired `Thread` against a fake project for testing.
///
/// For `TestModel::Fake` this produces a `FakeLanguageModel` whose completion
/// streams the test drives by hand; for `TestModel::Sonnet4` it initializes
/// the real client stack and authenticates the provider.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed a settings file whose default profile enables every test tool, so
    // individual tests don't have to toggle tool availability themselves.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Live model: bring up the production HTTP client and the
                // language-model provider registry.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Mirror changes to the fake settings file into the SettingsStore,
        // like the real app's settings watcher does.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate its
                // provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2484
/// Runs before tests (via `ctor`) and enables `env_logger` output, but only
/// when the caller opted in by setting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2492
2493fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2494 let fs = fs.clone();
2495 cx.spawn({
2496 async move |cx| {
2497 let mut new_settings_content_rx = settings::watch_config_file(
2498 cx.background_executor(),
2499 fs,
2500 paths::settings_file().clone(),
2501 );
2502
2503 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2504 cx.update(|cx| {
2505 SettingsStore::update_global(cx, |settings, cx| {
2506 settings.set_user_settings(&new_settings_content, cx)
2507 })
2508 })
2509 .ok();
2510 }
2511 }
2512 })
2513 .detach();
2514}
2515
2516fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2517 completion
2518 .tools
2519 .iter()
2520 .map(|tool| tool.name.clone())
2521 .collect()
2522}
2523
/// Registers and starts a fake MCP context server named `name` that advertises
/// `tools`, backed by an in-memory transport.
///
/// Returns a receiver yielding each incoming `CallTool` request paired with a
/// oneshot sender the test uses to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store knows about it; the
    // command is a dummy since the fake transport never spawns a process.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report tool support so the client will list tools.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Serve the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and block until it responds via
        // the paired oneshot sender.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the store finish the initialize/list-tools handshake before tests
    // start issuing tool calls.
    cx.run_until_parked();
    mcp_tool_calls_rx
}