mod.rs

use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use cloud_llm_client::CompletionIntent;
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use fs::{FakeFs, Fs};
use futures::{
    StreamExt,
    channel::{
        mpsc::{self, UnboundedReceiver},
        oneshot,
    },
};
use gpui::{
    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::{
    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
};
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::{Settings, SettingsStore};
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;

mod test_tools;
use test_tools::*;

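/// A plain send should stream the model's text into an assistant message and end the turn.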
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

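/// Thinking chunks should be captured alongside text and rendered as a <think> block in the message markdown.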
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

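/// The system prompt should reflect project context (here, the shell) and include tool-related sections when tools are registered.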
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

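/// With no tools registered, the tool-related sections should be omitted from the system prompt.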
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

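/// Only the most recent user message or tool result in a request should be marked as cacheable.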
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

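/// End-to-end (requires the "e2e" feature): exercises tool calls that finish both before and after streaming stops.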
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}

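/// End-to-end (requires the "e2e" feature): tool call input should stream into the thread incrementally.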
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Stream a tool call and verify that we observe partially-streamed input along the way.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

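/// Approving, rejecting, and always-allowing tool calls should produce the corresponding tool results,
/// and "always allow" should suppress further authorization requests.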
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first tool call.
    tool_call_auth_1
        .response
        .send(tool_call_auth_1.options[1].id.clone())
        .unwrap();
    cx.run_until_parked();

    // Reject the second tool call.
    tool_call_auth_2
        .response
        .send(tool_call_auth_2.options[2].id.clone())
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools.
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(tool_call_auth_3.options[0].id.clone())
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}

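/// A call to a tool that doesn't exist should surface as a tool call that immediately fails.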
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}

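/// Resuming after the tool use limit is reached should resend the history plus a "Continue where you left off" user message.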
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}

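/// Sending a new message after the tool use limit is reached should keep the pending tool result in the request.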
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}

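/// Waits for the next thread event and asserts that it is a tool call.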
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
    let event = events
        .next()
        .await
        .expect("no tool call event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCall(tool_call) => tool_call,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

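/// Waits for the next thread event and asserts that it is a tool call fields update.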
async fn expect_tool_call_update_fields(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
    let event = events
        .next()
        .await
        .expect("no tool call update event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

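/// Skips ahead to the next tool call authorization event, asserting that the offered permission
/// options come in the expected order.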
async fn next_tool_call_authorization(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
    loop {
        let event = events
            .next()
            .await
            .expect("no tool call authorization event received")
            .unwrap();
        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let permission_kinds = tool_call_authorization
                .options
                .iter()
                .map(|o| o.kind)
                .collect::<Vec<_>>();
            assert_eq!(
                permission_kinds,
                vec![
                    acp::PermissionOptionKind::AllowAlways,
                    acp::PermissionOptionKind::AllowOnce,
                    acp::PermissionOptionKind::RejectOnce,
                ]
            );
            return tool_call_authorization;
        }
    }
}

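/// End-to-end (requires the "e2e" feature): multiple tool calls issued in one message should run concurrently.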
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}

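/// Switching agent profiles should change which tools are sent to the model.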
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Switch to the test-1 profile, and verify that it has the echo and delay tools.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to the test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}

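/// MCP (context server) tools should be exposed to the model, and a name collision with a native
/// tool should be resolved by prefixing the server name.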
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}

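/// Conflicting or overly long MCP tool names should be disambiguated and truncated to fit MAX_TOOL_NAME_LENGTH.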
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}

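/// End-to-end (requires the "e2e" feature): cancelling a send should close the event stream with a
/// Cancelled stop reason, even while a tool is still running, and the thread should accept new
/// messages afterwards.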
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

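/// Starting a new send should cancel the send that is still in progress.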
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

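/// A send that already completed should not be reported as cancelled by a later send.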
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

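/// A Refusal stop reason should surface as a Refusal stop event and roll the thread back.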
1462#[gpui::test]
1463async fn test_refusal(cx: &mut TestAppContext) {
1464    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1465    let fake_model = model.as_fake();
1466
1467    let events = thread
1468        .update(cx, |thread, cx| {
1469            thread.send(UserMessageId::new(), ["Hello"], cx)
1470        })
1471        .unwrap();
1472    cx.run_until_parked();
1473    thread.read_with(cx, |thread, _| {
1474        assert_eq!(
1475            thread.to_markdown(),
1476            indoc! {"
1477                ## User
1478
1479                Hello
1480            "}
1481        );
1482    });
1483
1484    fake_model.send_last_completion_stream_text_chunk("Hey!");
1485    cx.run_until_parked();
1486    thread.read_with(cx, |thread, _| {
1487        assert_eq!(
1488            thread.to_markdown(),
1489            indoc! {"
1490                ## User
1491
1492                Hello
1493
1494                ## Assistant
1495
1496                Hey!
1497            "}
1498        );
1499    });
1500
1501    // If the model refuses to continue, the thread should remove all the messages after the last user message.
1502    fake_model
1503        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1504    let events = events.collect::<Vec<_>>().await;
1505    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1506    thread.read_with(cx, |thread, _| {
1507        assert_eq!(thread.to_markdown(), "");
1508    });
1509}
1510
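/// Truncating at the first message clears the thread and its token usage, and new messages
/// can still be sent afterwards.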
1511#[gpui::test]
1512async fn test_truncate_first_message(cx: &mut TestAppContext) {
1513    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1514    let fake_model = model.as_fake();
1515
1516    let message_id = UserMessageId::new();
1517    thread
1518        .update(cx, |thread, cx| {
1519            thread.send(message_id.clone(), ["Hello"], cx)
1520        })
1521        .unwrap();
1522    cx.run_until_parked();
1523    thread.read_with(cx, |thread, _| {
1524        assert_eq!(
1525            thread.to_markdown(),
1526            indoc! {"
1527                ## User
1528
1529                Hello
1530            "}
1531        );
1532        assert_eq!(thread.latest_token_usage(), None);
1533    });
1534
1535    fake_model.send_last_completion_stream_text_chunk("Hey!");
1536    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1537        language_model::TokenUsage {
1538            input_tokens: 32_000,
1539            output_tokens: 16_000,
1540            cache_creation_input_tokens: 0,
1541            cache_read_input_tokens: 0,
1542        },
1543    ));
1544    cx.run_until_parked();
1545    thread.read_with(cx, |thread, _| {
1546        assert_eq!(
1547            thread.to_markdown(),
1548            indoc! {"
1549                ## User
1550
1551                Hello
1552
1553                ## Assistant
1554
1555                Hey!
1556            "}
1557        );
1558        assert_eq!(
1559            thread.latest_token_usage(),
1560            Some(acp_thread::TokenUsage {
1561                used_tokens: 32_000 + 16_000,
1562                max_tokens: 1_000_000,
1563            })
1564        );
1565    });
1566
1567    thread
1568        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1569        .unwrap();
1570    cx.run_until_parked();
1571    thread.read_with(cx, |thread, _| {
1572        assert_eq!(thread.to_markdown(), "");
1573        assert_eq!(thread.latest_token_usage(), None);
1574    });
1575
1576    // Ensure we can still send a new message after truncation.
1577    thread
1578        .update(cx, |thread, cx| {
1579            thread.send(UserMessageId::new(), ["Hi"], cx)
1580        })
1581        .unwrap();
1582    thread.update(cx, |thread, _cx| {
1583        assert_eq!(
1584            thread.to_markdown(),
1585            indoc! {"
1586                ## User
1587
1588                Hi
1589            "}
1590        );
1591    });
1592    cx.run_until_parked();
1593    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1594    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1595        language_model::TokenUsage {
1596            input_tokens: 40_000,
1597            output_tokens: 20_000,
1598            cache_creation_input_tokens: 0,
1599            cache_read_input_tokens: 0,
1600        },
1601    ));
1602    cx.run_until_parked();
1603    thread.read_with(cx, |thread, _| {
1604        assert_eq!(
1605            thread.to_markdown(),
1606            indoc! {"
1607                ## User
1608
1609                Hi
1610
1611                ## Assistant
1612
1613                Ahoy!
1614            "}
1615        );
1616
1617        assert_eq!(
1618            thread.latest_token_usage(),
1619            Some(acp_thread::TokenUsage {
1620                used_tokens: 40_000 + 20_000,
1621                max_tokens: 1_000_000,
1622            })
1623        );
1624    });
1625}
1626
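/// Truncating at the second message restores the thread (and its token usage) to the state
/// after the first exchange.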
1627#[gpui::test]
1628async fn test_truncate_second_message(cx: &mut TestAppContext) {
1629    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1630    let fake_model = model.as_fake();
1631
1632    thread
1633        .update(cx, |thread, cx| {
1634            thread.send(UserMessageId::new(), ["Message 1"], cx)
1635        })
1636        .unwrap();
1637    cx.run_until_parked();
1638    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
1639    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1640        language_model::TokenUsage {
1641            input_tokens: 32_000,
1642            output_tokens: 16_000,
1643            cache_creation_input_tokens: 0,
1644            cache_read_input_tokens: 0,
1645        },
1646    ));
1647    fake_model.end_last_completion_stream();
1648    cx.run_until_parked();
1649
1650    let assert_first_message_state = |cx: &mut TestAppContext| {
1651        thread.clone().read_with(cx, |thread, _| {
1652            assert_eq!(
1653                thread.to_markdown(),
1654                indoc! {"
1655                    ## User
1656
1657                    Message 1
1658
1659                    ## Assistant
1660
1661                    Message 1 response
1662                "}
1663            );
1664
1665            assert_eq!(
1666                thread.latest_token_usage(),
1667                Some(acp_thread::TokenUsage {
1668                    used_tokens: 32_000 + 16_000,
1669                    max_tokens: 1_000_000,
1670                })
1671            );
1672        });
1673    };
1674
1675    assert_first_message_state(cx);
1676
1677    let second_message_id = UserMessageId::new();
1678    thread
1679        .update(cx, |thread, cx| {
1680            thread.send(second_message_id.clone(), ["Message 2"], cx)
1681        })
1682        .unwrap();
1683    cx.run_until_parked();
1684
1685    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
1686    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1687        language_model::TokenUsage {
1688            input_tokens: 40_000,
1689            output_tokens: 20_000,
1690            cache_creation_input_tokens: 0,
1691            cache_read_input_tokens: 0,
1692        },
1693    ));
1694    fake_model.end_last_completion_stream();
1695    cx.run_until_parked();
1696
1697    thread.read_with(cx, |thread, _| {
1698        assert_eq!(
1699            thread.to_markdown(),
1700            indoc! {"
1701                ## User
1702
1703                Message 1
1704
1705                ## Assistant
1706
1707                Message 1 response
1708
1709                ## User
1710
1711                Message 2
1712
1713                ## Assistant
1714
1715                Message 2 response
1716            "}
1717        );
1718
1719        assert_eq!(
1720            thread.latest_token_usage(),
1721            Some(acp_thread::TokenUsage {
1722                used_tokens: 40_000 + 20_000,
1723                max_tokens: 1_000_000,
1724            })
1725        );
1726    });
1727
1728    thread
1729        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
1730        .unwrap();
1731    cx.run_until_parked();
1732
1733    assert_first_message_state(cx);
1734}
1735
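/// The summarization model generates the thread title from the first exchange; later
/// exchanges must not trigger another title generation.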
1736#[gpui::test]
1737async fn test_title_generation(cx: &mut TestAppContext) {
1738    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1739    let fake_model = model.as_fake();
1740
1741    let summary_model = Arc::new(FakeLanguageModel::default());
1742    thread.update(cx, |thread, cx| {
1743        thread.set_summarization_model(Some(summary_model.clone()), cx)
1744    });
1745
1746    let send = thread
1747        .update(cx, |thread, cx| {
1748            thread.send(UserMessageId::new(), ["Hello"], cx)
1749        })
1750        .unwrap();
1751    cx.run_until_parked();
1752
1753    fake_model.send_last_completion_stream_text_chunk("Hey!");
1754    fake_model.end_last_completion_stream();
1755    cx.run_until_parked();
1756    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1757
1758    // Simulate the summary model generating a title; only the first line of its output becomes the title.
1759    summary_model.send_last_completion_stream_text_chunk("Hello ");
1760    summary_model.send_last_completion_stream_text_chunk("world\nG");
1761    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1762    summary_model.end_last_completion_stream();
1763    send.collect::<Vec<_>>().await;
1764    cx.run_until_parked();
1765    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1766
1767    // Send another message, ensuring no title is generated this time.
1768    let send = thread
1769        .update(cx, |thread, cx| {
1770            thread.send(UserMessageId::new(), ["Hello again"], cx)
1771        })
1772        .unwrap();
1773    cx.run_until_parked();
1774    fake_model.send_last_completion_stream_text_chunk("Hey again!");
1775    fake_model.end_last_completion_stream();
1776    cx.run_until_parked();
1777    assert_eq!(summary_model.pending_completions(), Vec::new());
1778    send.collect::<Vec<_>>().await;
1779    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1780}
1781
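/// Completion requests built while a tool use is still pending should omit that tool use,
/// but keep completed tool uses together with their results.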
1782#[gpui::test]
1783async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
1784    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1785    let fake_model = model.as_fake();
1786
1787    let _events = thread
1788        .update(cx, |thread, cx| {
1789            thread.add_tool(ToolRequiringPermission);
1790            thread.add_tool(EchoTool);
1791            thread.send(UserMessageId::new(), ["Hey!"], cx)
1792        })
1793        .unwrap();
1794    cx.run_until_parked();
1795
1796    let permission_tool_use = LanguageModelToolUse {
1797        id: "tool_id_1".into(),
1798        name: ToolRequiringPermission::name().into(),
1799        raw_input: "{}".into(),
1800        input: json!({}),
1801        is_input_complete: true,
1802        thought_signature: None,
1803    };
1804    let echo_tool_use = LanguageModelToolUse {
1805        id: "tool_id_2".into(),
1806        name: EchoTool::name().into(),
1807        raw_input: json!({"text": "test"}).to_string(),
1808        input: json!({"text": "test"}),
1809        is_input_complete: true,
1810        thought_signature: None,
1811    };
1812    fake_model.send_last_completion_stream_text_chunk("Hi!");
1813    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1814        permission_tool_use,
1815    ));
1816    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1817        echo_tool_use.clone(),
1818    ));
1819    fake_model.end_last_completion_stream();
1820    cx.run_until_parked();
1821
1822    // Ensure tool uses that are still pending are omitted when building a request, while completed ones keep their results.
1823    let request = thread
1824        .read_with(cx, |thread, cx| {
1825            thread.build_completion_request(CompletionIntent::EditFile, cx)
1826        })
1827        .unwrap();
1828    assert_eq!(
1829        request.messages[1..],
1830        vec![
1831            LanguageModelRequestMessage {
1832                role: Role::User,
1833                content: vec!["Hey!".into()],
1834                cache: true
1835            },
1836            LanguageModelRequestMessage {
1837                role: Role::Assistant,
1838                content: vec![
1839                    MessageContent::Text("Hi!".into()),
1840                    MessageContent::ToolUse(echo_tool_use.clone())
1841                ],
1842                cache: false
1843            },
1844            LanguageModelRequestMessage {
1845                role: Role::User,
1846                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
1847                    tool_use_id: echo_tool_use.id.clone(),
1848                    tool_name: echo_tool_use.name,
1849                    is_error: false,
1850                    content: "test".into(),
1851                    output: Some("test".into())
1852                })],
1853                cache: false
1854            },
1855        ],
1856    );
1857}
1858
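/// End-to-end exercise of `NativeAgentConnection`: model listing and selection, prompting,
/// cancellation, and session cleanup when the ACP thread is dropped.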
1859#[gpui::test]
1860async fn test_agent_connection(cx: &mut TestAppContext) {
1861    cx.update(settings::init);
1862    let templates = Templates::new();
1863
1864    // Initialize language model system with test provider
1865    cx.update(|cx| {
1866        gpui_tokio::init(cx);
1867
1868        let http_client = FakeHttpClient::with_404_response();
1869        let clock = Arc::new(clock::FakeSystemClock::new());
1870        let client = Client::new(clock, http_client, cx);
1871        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1872        language_model::init(client.clone(), cx);
1873        language_models::init(user_store, client.clone(), cx);
1874        LanguageModelRegistry::test(cx);
1875    });
1876    cx.executor().forbid_parking();
1877
1878    // Create a project for new_thread
1879    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
1880    fake_fs.insert_tree(path!("/test"), json!({})).await;
1881    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
1882    let cwd = Path::new("/test");
1883    let text_thread_store =
1884        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
1885    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
1886
1887    // Create agent and connection
1888    let agent = NativeAgent::new(
1889        project.clone(),
1890        history_store,
1891        templates.clone(),
1892        None,
1893        fake_fs.clone(),
1894        &mut cx.to_async(),
1895    )
1896    .await
1897    .unwrap();
1898    let connection = NativeAgentConnection(agent.clone());
1899
1900    // Create a thread using new_thread
1901    let connection_rc = Rc::new(connection.clone());
1902    let acp_thread = cx
1903        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1904        .await
1905        .expect("new_thread should succeed");
1906
1907    // Get the session_id from the AcpThread
1908    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1909
1910    // Test model_selector returns Some
1911    let selector_opt = connection.model_selector(&session_id);
1912    assert!(
1913        selector_opt.is_some(),
1914        "agent should always support ModelSelector"
1915    );
1916    let selector = selector_opt.unwrap();
1917
1918    // Test list_models
1919    let listed_models = cx
1920        .update(|cx| selector.list_models(cx))
1921        .await
1922        .expect("list_models should succeed");
1923    let AgentModelList::Grouped(listed_models) = listed_models else {
1924        panic!("Unexpected model list type");
1925    };
1926    assert!(!listed_models.is_empty(), "should have at least one model");
1927    assert_eq!(
1928        listed_models[&AgentModelGroupName("Fake".into())][0]
1929            .id
1930            .0
1931            .as_ref(),
1932        "fake/fake"
1933    );
1934
1935    // Test selected_model returns the default
1936    let model = cx
1937        .update(|cx| selector.selected_model(cx))
1938        .await
1939        .expect("selected_model should succeed");
1940    let model = cx
1941        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1942        .unwrap();
1943    let model = model.as_fake();
1944    assert_eq!(model.id().0, "fake", "should return default model");
1945
1946    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1947    cx.run_until_parked();
1948    model.send_last_completion_stream_text_chunk("def");
1949    cx.run_until_parked();
1950    acp_thread.read_with(cx, |thread, cx| {
1951        assert_eq!(
1952            thread.to_markdown(cx),
1953            indoc! {"
1954                ## User
1955
1956                abc
1957
1958                ## Assistant
1959
1960                def
1961
1962            "}
1963        )
1964    });
1965
1966    // Test cancel
1967    cx.update(|cx| connection.cancel(&session_id, cx));
1968    request.await.expect("prompt should resolve gracefully after cancellation");
1969
1970    // Ensure that dropping the ACP thread causes the native thread to be
1971    // dropped as well.
1972    cx.update(|_| drop(acp_thread));
1973    let result = cx
1974        .update(|cx| {
1975            connection.prompt(
1976                Some(acp_thread::UserMessageId::new()),
1977                acp::PromptRequest {
1978                    session_id: session_id.clone(),
1979                    prompt: vec!["ghi".into()],
1980                    meta: None,
1981                },
1982                cx,
1983            )
1984        })
1985        .await;
1986    assert_eq!(
1987        result.as_ref().unwrap_err().to_string(),
1988        "Session not found",
1989        "unexpected result: {:?}",
1990        result
1991    );
1992}
1993
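/// A streamed tool call should produce the expected sequence of ACP updates: the initial
/// pending call, completed input, in-progress status, streamed content, and completion.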
1994#[gpui::test]
1995async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1996    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1997    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1998    let fake_model = model.as_fake();
1999
2000    let mut events = thread
2001        .update(cx, |thread, cx| {
2002            thread.send(UserMessageId::new(), ["Think"], cx)
2003        })
2004        .unwrap();
2005    cx.run_until_parked();
2006
2007    // Simulate streaming partial input.
2008    let input = json!({});
2009    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2010        LanguageModelToolUse {
2011            id: "1".into(),
2012            name: ThinkingTool::name().into(),
2013            raw_input: input.to_string(),
2014            input,
2015            is_input_complete: false,
2016            thought_signature: None,
2017        },
2018    ));
2019
2020    // Input streaming completed
2021    let input = json!({ "content": "Thinking hard!" });
2022    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2023        LanguageModelToolUse {
2024            id: "1".into(),
2025            name: "thinking".into(),
2026            raw_input: input.to_string(),
2027            input,
2028            is_input_complete: true,
2029            thought_signature: None,
2030        },
2031    ));
2032    fake_model.end_last_completion_stream();
2033    cx.run_until_parked();
2034
2035    let tool_call = expect_tool_call(&mut events).await;
2036    assert_eq!(
2037        tool_call,
2038        acp::ToolCall {
2039            id: acp::ToolCallId("1".into()),
2040            title: "Thinking".into(),
2041            kind: acp::ToolKind::Think,
2042            status: acp::ToolCallStatus::Pending,
2043            content: vec![],
2044            locations: vec![],
2045            raw_input: Some(json!({})),
2046            raw_output: None,
2047            meta: Some(json!({ "tool_name": "thinking" })),
2048        }
2049    );
2050    let update = expect_tool_call_update_fields(&mut events).await;
2051    assert_eq!(
2052        update,
2053        acp::ToolCallUpdate {
2054            id: acp::ToolCallId("1".into()),
2055            fields: acp::ToolCallUpdateFields {
2056                title: Some("Thinking".into()),
2057                kind: Some(acp::ToolKind::Think),
2058                raw_input: Some(json!({ "content": "Thinking hard!" })),
2059                ..Default::default()
2060            },
2061            meta: None,
2062        }
2063    );
2064    let update = expect_tool_call_update_fields(&mut events).await;
2065    assert_eq!(
2066        update,
2067        acp::ToolCallUpdate {
2068            id: acp::ToolCallId("1".into()),
2069            fields: acp::ToolCallUpdateFields {
2070                status: Some(acp::ToolCallStatus::InProgress),
2071                ..Default::default()
2072            },
2073            meta: None,
2074        }
2075    );
2076    let update = expect_tool_call_update_fields(&mut events).await;
2077    assert_eq!(
2078        update,
2079        acp::ToolCallUpdate {
2080            id: acp::ToolCallId("1".into()),
2081            fields: acp::ToolCallUpdateFields {
2082                content: Some(vec!["Thinking hard!".into()]),
2083                ..Default::default()
2084            },
2085            meta: None,
2086        }
2087    );
2088    let update = expect_tool_call_update_fields(&mut events).await;
2089    assert_eq!(
2090        update,
2091        acp::ToolCallUpdate {
2092            id: acp::ToolCallId("1".into()),
2093            fields: acp::ToolCallUpdateFields {
2094                status: Some(acp::ToolCallStatus::Completed),
2095                raw_output: Some("Finished thinking.".into()),
2096                ..Default::default()
2097            },
2098            meta: None,
2099        }
2100    );
2101}
2102
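/// A turn that completes without provider errors should emit no `Retry` events.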
2103#[gpui::test]
2104async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2105    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2106    let fake_model = model.as_fake();
2107
2108    let mut events = thread
2109        .update(cx, |thread, cx| {
2110            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2111            thread.send(UserMessageId::new(), ["Hello!"], cx)
2112        })
2113        .unwrap();
2114    cx.run_until_parked();
2115
2116    fake_model.send_last_completion_stream_text_chunk("Hey!");
2117    fake_model.end_last_completion_stream();
2118
2119    let mut retry_events = Vec::new();
2120    while let Some(Ok(event)) = events.next().await {
2121        match event {
2122            ThreadEvent::Retry(retry_status) => {
2123                retry_events.push(retry_status);
2124            }
2125            ThreadEvent::Stop(..) => break,
2126            _ => {}
2127        }
2128    }
2129
2130    assert_eq!(retry_events.len(), 0);
2131    thread.read_with(cx, |thread, _cx| {
2132        assert_eq!(
2133            thread.to_markdown(),
2134            indoc! {"
2135                ## User
2136
2137                Hello!
2138
2139                ## Assistant
2140
2141                Hey!
2142            "}
2143        )
2144    });
2145}
2146
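/// A recoverable provider error (server overloaded) should trigger a single retry after the
/// advertised delay, and the turn should resume where it left off.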
2147#[gpui::test]
2148async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2149    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2150    let fake_model = model.as_fake();
2151
2152    let mut events = thread
2153        .update(cx, |thread, cx| {
2154            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2155            thread.send(UserMessageId::new(), ["Hello!"], cx)
2156        })
2157        .unwrap();
2158    cx.run_until_parked();
2159
2160    fake_model.send_last_completion_stream_text_chunk("Hey,");
2161    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2162        provider: LanguageModelProviderName::new("Anthropic"),
2163        retry_after: Some(Duration::from_secs(3)),
2164    });
2165    fake_model.end_last_completion_stream();
2166
2167    cx.executor().advance_clock(Duration::from_secs(3));
2168    cx.run_until_parked();
2169
2170    fake_model.send_last_completion_stream_text_chunk("there!");
2171    fake_model.end_last_completion_stream();
2172    cx.run_until_parked();
2173
2174    let mut retry_events = Vec::new();
2175    while let Some(Ok(event)) = events.next().await {
2176        match event {
2177            ThreadEvent::Retry(retry_status) => {
2178                retry_events.push(retry_status);
2179            }
2180            ThreadEvent::Stop(..) => break,
2181            _ => {}
2182        }
2183    }
2184
2185    assert_eq!(retry_events.len(), 1);
2186    assert!(matches!(
2187        retry_events[0],
2188        acp_thread::RetryStatus { attempt: 1, .. }
2189    ));
2190    thread.read_with(cx, |thread, _cx| {
2191        assert_eq!(
2192            thread.to_markdown(),
2193            indoc! {"
2194                ## User
2195
2196                Hello!
2197
2198                ## Assistant
2199
2200                Hey,
2201
2202                [resume]
2203
2204                ## Assistant
2205
2206                there!
2207            "}
2208        )
2209    });
2210}
2211
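/// Tool calls emitted before a provider error should be resolved and their results included
/// in the retried request.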
2212#[gpui::test]
2213async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
2214    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2215    let fake_model = model.as_fake();
2216
2217    let events = thread
2218        .update(cx, |thread, cx| {
2219            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2220            thread.add_tool(EchoTool);
2221            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
2222        })
2223        .unwrap();
2224    cx.run_until_parked();
2225
2226    let tool_use_1 = LanguageModelToolUse {
2227        id: "tool_1".into(),
2228        name: EchoTool::name().into(),
2229        raw_input: json!({"text": "test"}).to_string(),
2230        input: json!({"text": "test"}),
2231        is_input_complete: true,
2232        thought_signature: None,
2233    };
2234    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2235        tool_use_1.clone(),
2236    ));
2237    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2238        provider: LanguageModelProviderName::new("Anthropic"),
2239        retry_after: Some(Duration::from_secs(3)),
2240    });
2241    fake_model.end_last_completion_stream();
2242
2243    cx.executor().advance_clock(Duration::from_secs(3));
2244    let completion = fake_model.pending_completions().pop().unwrap();
2245    assert_eq!(
2246        completion.messages[1..],
2247        vec![
2248            LanguageModelRequestMessage {
2249                role: Role::User,
2250                content: vec!["Call the echo tool!".into()],
2251                cache: false
2252            },
2253            LanguageModelRequestMessage {
2254                role: Role::Assistant,
2255                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
2256                cache: false
2257            },
2258            LanguageModelRequestMessage {
2259                role: Role::User,
2260                content: vec![language_model::MessageContent::ToolResult(
2261                    LanguageModelToolResult {
2262                        tool_use_id: tool_use_1.id.clone(),
2263                        tool_name: tool_use_1.name.clone(),
2264                        is_error: false,
2265                        content: "test".into(),
2266                        output: Some("test".into())
2267                    }
2268                )],
2269                cache: true
2270            },
2271        ]
2272    );
2273
2274    fake_model.send_last_completion_stream_text_chunk("Done");
2275    fake_model.end_last_completion_stream();
2276    cx.run_until_parked();
2277    events.collect::<Vec<_>>().await;
2278    thread.read_with(cx, |thread, _cx| {
2279        assert_eq!(
2280            thread.last_message(),
2281            Some(Message::Agent(AgentMessage {
2282                content: vec![AgentMessageContent::Text("Done".into())],
2283                tool_results: IndexMap::default()
2284            }))
2285        );
2286    })
2287}
2288
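/// After `MAX_RETRY_ATTEMPTS` consecutive failures the error is surfaced to the caller
/// instead of scheduling another retry.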
2289#[gpui::test]
2290async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2291    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2292    let fake_model = model.as_fake();
2293
2294    let mut events = thread
2295        .update(cx, |thread, cx| {
2296            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2297            thread.send(UserMessageId::new(), ["Hello!"], cx)
2298        })
2299        .unwrap();
2300    cx.run_until_parked();
2301
2302    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2303        fake_model.send_last_completion_stream_error(
2304            LanguageModelCompletionError::ServerOverloaded {
2305                provider: LanguageModelProviderName::new("Anthropic"),
2306                retry_after: Some(Duration::from_secs(3)),
2307            },
2308        );
2309        fake_model.end_last_completion_stream();
2310        cx.executor().advance_clock(Duration::from_secs(3));
2311        cx.run_until_parked();
2312    }
2313
2314    let mut errors = Vec::new();
2315    let mut retry_events = Vec::new();
2316    while let Some(event) = events.next().await {
2317        match event {
2318            Ok(ThreadEvent::Retry(retry_status)) => {
2319                retry_events.push(retry_status);
2320            }
2321            Ok(ThreadEvent::Stop(..)) => break,
2322            Err(error) => errors.push(error),
2323            _ => {}
2324        }
2325    }
2326
2327    assert_eq!(
2328        retry_events.len(),
2329        crate::thread::MAX_RETRY_ATTEMPTS as usize
2330    );
2331    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2332        assert_eq!(retry_events[i].attempt, i + 1);
2333    }
2334    assert_eq!(errors.len(), 1);
2335    let error = errors[0]
2336        .downcast_ref::<LanguageModelCompletionError>()
2337        .unwrap();
2338    assert!(matches!(
2339        error,
2340        LanguageModelCompletionError::ServerOverloaded { .. }
2341    ));
2342}
2343
2344/// Extracts the stop events from a turn's event stream for asserting against in tests.
2345fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2346    result_events
2347        .into_iter()
2348        .filter_map(|event| match event.unwrap() {
2349            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2350            _ => None,
2351        })
2352        .collect()
2353}
2354
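/// Harness returned by [`setup`]: the model under test, the thread, and supporting state
/// used by individual tests.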
2355struct ThreadTest {
2356    model: Arc<dyn LanguageModel>,
2357    thread: Entity<Thread>,
2358    project_context: Entity<ProjectContext>,
2359    context_server_store: Entity<ContextServerStore>,
2360    fs: Arc<FakeFs>,
2361}
2362
2363enum TestModel {
2364    Sonnet4,
2365    Fake,
2366}
2367
2368impl TestModel {
2369    fn id(&self) -> LanguageModelId {
2370        match self {
2371            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2372            TestModel::Fake => unreachable!(),
2373        }
2374    }
2375}
2376
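/// Builds a [`ThreadTest`] on a fake filesystem with a `test-profile` that enables the test
/// tools, using either the fake model or an authenticated provider model (e.g. Sonnet 4).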
2377async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
2378    cx.executor().allow_parking();
2379
2380    let fs = FakeFs::new(cx.background_executor.clone());
2381    fs.create_dir(paths::settings_file().parent().unwrap())
2382        .await
2383        .unwrap();
2384    fs.insert_file(
2385        paths::settings_file(),
2386        json!({
2387            "agent": {
2388                "default_profile": "test-profile",
2389                "profiles": {
2390                    "test-profile": {
2391                        "name": "Test Profile",
2392                        "tools": {
2393                            EchoTool::name(): true,
2394                            DelayTool::name(): true,
2395                            WordListTool::name(): true,
2396                            ToolRequiringPermission::name(): true,
2397                            InfiniteTool::name(): true,
2398                            ThinkingTool::name(): true,
2399                        }
2400                    }
2401                }
2402            }
2403        })
2404        .to_string()
2405        .into_bytes(),
2406    )
2407    .await;
2408
2409    cx.update(|cx| {
2410        settings::init(cx);
2411
2412        match model {
2413            TestModel::Fake => {}
2414            TestModel::Sonnet4 => {
2415                gpui_tokio::init(cx);
2416                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
2417                cx.set_http_client(Arc::new(http_client));
2418                let client = Client::production(cx);
2419                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2420                language_model::init(client.clone(), cx);
2421                language_models::init(user_store, client.clone(), cx);
2422            }
2423        };
2424
2425        watch_settings(fs.clone(), cx);
2426    });
2427
2428    let templates = Templates::new();
2429
2430    fs.insert_tree(path!("/test"), json!({})).await;
2431    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2432
2433    let model = cx
2434        .update(|cx| {
2435            if let TestModel::Fake = model {
2436                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
2437            } else {
2438                let model_id = model.id();
2439                let models = LanguageModelRegistry::read_global(cx);
2440                let model = models
2441                    .available_models(cx)
2442                    .find(|model| model.id() == model_id)
2443                    .unwrap();
2444
2445                let provider = models.provider(&model.provider_id()).unwrap();
2446                let authenticated = provider.authenticate(cx);
2447
2448                cx.spawn(async move |_cx| {
2449                    authenticated.await.unwrap();
2450                    model
2451                })
2452            }
2453        })
2454        .await;
2455
2456    let project_context = cx.new(|_cx| ProjectContext::default());
2457    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
2458    let context_server_registry =
2459        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
2460    let thread = cx.new(|cx| {
2461        Thread::new(
2462            project,
2463            project_context.clone(),
2464            context_server_registry,
2465            templates,
2466            Some(model.clone()),
2467            cx,
2468        )
2469    });
2470    ThreadTest {
2471        model,
2472        thread,
2473        project_context,
2474        context_server_store,
2475        fs,
2476    }
2477}
2478
2479#[cfg(test)]
2480#[ctor::ctor]
2481fn init_logger() {
2482    if std::env::var("RUST_LOG").is_ok() {
2483        env_logger::init();
2484    }
2485}
2486
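/// Keeps the global `SettingsStore` in sync with changes to the settings file on the fake
/// filesystem.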
2487fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2488    let fs = fs.clone();
2489    cx.spawn({
2490        async move |cx| {
2491            let mut new_settings_content_rx = settings::watch_config_file(
2492                cx.background_executor(),
2493                fs,
2494                paths::settings_file().clone(),
2495            );
2496
2497            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2498                cx.update(|cx| {
2499                    SettingsStore::update_global(cx, |settings, cx| {
2500                        settings.set_user_settings(&new_settings_content, cx)
2501                    })
2502                })
2503                .ok();
2504            }
2505        }
2506    })
2507    .detach();
2508}
2509
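/// Returns the names of the tools included in a completion request.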
2510fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2511    completion
2512        .tools
2513        .iter()
2514        .map(|tool| tool.name.clone())
2515        .collect()
2516}
2517
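/// Registers a fake MCP context server exposing the given tools and returns a channel on
/// which the test receives tool calls along with a sender for the response.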
2518fn setup_context_server(
2519    name: &'static str,
2520    tools: Vec<context_server::types::Tool>,
2521    context_server_store: &Entity<ContextServerStore>,
2522    cx: &mut TestAppContext,
2523) -> mpsc::UnboundedReceiver<(
2524    context_server::types::CallToolParams,
2525    oneshot::Sender<context_server::types::CallToolResponse>,
2526)> {
2527    cx.update(|cx| {
2528        let mut settings = ProjectSettings::get_global(cx).clone();
2529        settings.context_servers.insert(
2530            name.into(),
2531            project::project_settings::ContextServerSettings::Custom {
2532                enabled: true,
2533                command: ContextServerCommand {
2534                    path: "somebinary".into(),
2535                    args: Vec::new(),
2536                    env: None,
2537                    timeout: None,
2538                },
2539            },
2540        );
2541        ProjectSettings::override_global(settings, cx);
2542    });
2543
2544    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
2545    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
2546        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
2547            context_server::types::InitializeResponse {
2548                protocol_version: context_server::types::ProtocolVersion(
2549                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
2550                ),
2551                server_info: context_server::types::Implementation {
2552                    name: name.into(),
2553                    version: "1.0.0".to_string(),
2554                },
2555                capabilities: context_server::types::ServerCapabilities {
2556                    tools: Some(context_server::types::ToolsCapabilities {
2557                        list_changed: Some(true),
2558                    }),
2559                    ..Default::default()
2560                },
2561                meta: None,
2562            }
2563        })
2564        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
2565            let tools = tools.clone();
2566            async move {
2567                context_server::types::ListToolsResponse {
2568                    tools,
2569                    next_cursor: None,
2570                    meta: None,
2571                }
2572            }
2573        })
2574        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
2575            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
2576            async move {
2577                let (response_tx, response_rx) = oneshot::channel();
2578                mcp_tool_calls_tx
2579                    .unbounded_send((params, response_tx))
2580                    .unwrap();
2581                response_rx.await.unwrap()
2582            }
2583        });
2584    context_server_store.update(cx, |store, cx| {
2585        store.start_server(
2586            Arc::new(ContextServer::new(
2587                ContextServerId(name.into()),
2588                Arc::new(fake_transport),
2589            )),
2590            cx,
2591        );
2592    });
2593    cx.run_until_parked();
2594    mcp_tool_calls_rx
2595}