// mod.rs — agent thread behavior tests.

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use fs::{FakeFs, Fs};
  11use futures::{
  12    StreamExt,
  13    channel::{
  14        mpsc::{self, UnboundedReceiver},
  15        oneshot,
  16    },
  17};
  18use gpui::{
  19    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  20};
  21use indoc::indoc;
  22use language_model::{
  23    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  24    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  25    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  26    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  27};
  28use pretty_assertions::assert_eq;
  29use project::{
  30    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  31};
  32use prompt_store::{ProjectContext, UserPromptId, UserRulesContext};
  33use reqwest_client::ReqwestClient;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use serde_json::json;
  37use settings::{Settings, SettingsStore};
  38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  39
  40use util::path;
  41
  42mod test_tools;
  43use test_tools::*;
  44
  45#[gpui::test]
  46async fn test_echo(cx: &mut TestAppContext) {
  47    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  48    let fake_model = model.as_fake();
  49
  50    let events = thread
  51        .update(cx, |thread, cx| {
  52            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  53        })
  54        .unwrap();
  55    cx.run_until_parked();
  56    fake_model.send_last_completion_stream_text_chunk("Hello");
  57    fake_model
  58        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  59    fake_model.end_last_completion_stream();
  60
  61    let events = events.collect().await;
  62    thread.update(cx, |thread, _cx| {
  63        assert_eq!(
  64            thread.last_message().unwrap().to_markdown(),
  65            indoc! {"
  66                ## Assistant
  67
  68                Hello
  69            "}
  70        )
  71    });
  72    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  73}
  74
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that a streamed Thinking event is rendered as a <think> block
    // preceding the assistant's visible text in the message markdown.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking chunk first, then the visible answer, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
 118
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies that the first request message is a system prompt reflecting the
    // project context (shell) and that it includes the diagnostics section when
    // at least one tool is registered.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Customize the project context so we can recognize it in the prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // Registering a tool should cause tool-related guidance to be included.
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
 163
 164#[gpui::test]
 165async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
 166    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 167    let fake_model = model.as_fake();
 168
 169    thread
 170        .update(cx, |thread, cx| {
 171            thread.send(UserMessageId::new(), ["abc"], cx)
 172        })
 173        .unwrap();
 174    cx.run_until_parked();
 175    let mut pending_completions = fake_model.pending_completions();
 176    assert_eq!(
 177        pending_completions.len(),
 178        1,
 179        "unexpected pending completions: {:?}",
 180        pending_completions
 181    );
 182
 183    let pending_completion = pending_completions.pop().unwrap();
 184    assert_eq!(pending_completion.messages[0].role, Role::System);
 185
 186    let system_message = &pending_completion.messages[0];
 187    let system_prompt = system_message.content[0].to_str().unwrap();
 188    assert!(
 189        !system_prompt.contains("## Tool Use"),
 190        "unexpected system message: {:?}",
 191        system_message
 192    );
 193    assert!(
 194        !system_prompt.contains("## Fixing Diagnostics"),
 195        "unexpected system message: {:?}",
 196        system_message
 197    );
 198}
 199
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the cache-marker policy: only the most recent message in each
    // request carries `cache: true`; everything earlier is sent uncached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare from index 1 onward.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool result message (the newest) is the only one marked cacheable.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end (real model) test: exercises tool calls that resolve both
    // before and after the model finishes streaming its turn.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // DelayTool's output ("Ding") should have been echoed back by the model.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 393
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end (real model) test: while tool-call input is still streaming,
    // the thread should expose a partially-populated tool use; once the input
    // completes, all fields must be present.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the later-alphabet key "g" may not
                        // have streamed yet — that's the partial state we want.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 445
 446#[gpui::test]
 447async fn test_tool_authorization(cx: &mut TestAppContext) {
 448    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 449    let fake_model = model.as_fake();
 450
 451    let mut events = thread
 452        .update(cx, |thread, cx| {
 453            thread.add_tool(ToolRequiringPermission);
 454            thread.send(UserMessageId::new(), ["abc"], cx)
 455        })
 456        .unwrap();
 457    cx.run_until_parked();
 458    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 459        LanguageModelToolUse {
 460            id: "tool_id_1".into(),
 461            name: ToolRequiringPermission::name().into(),
 462            raw_input: "{}".into(),
 463            input: json!({}),
 464            is_input_complete: true,
 465        },
 466    ));
 467    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 468        LanguageModelToolUse {
 469            id: "tool_id_2".into(),
 470            name: ToolRequiringPermission::name().into(),
 471            raw_input: "{}".into(),
 472            input: json!({}),
 473            is_input_complete: true,
 474        },
 475    ));
 476    fake_model.end_last_completion_stream();
 477    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 478    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 479
 480    // Approve the first
 481    tool_call_auth_1
 482        .response
 483        .send(tool_call_auth_1.options[1].id.clone())
 484        .unwrap();
 485    cx.run_until_parked();
 486
 487    // Reject the second
 488    tool_call_auth_2
 489        .response
 490        .send(tool_call_auth_1.options[2].id.clone())
 491        .unwrap();
 492    cx.run_until_parked();
 493
 494    let completion = fake_model.pending_completions().pop().unwrap();
 495    let message = completion.messages.last().unwrap();
 496    assert_eq!(
 497        message.content,
 498        vec![
 499            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 500                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 501                tool_name: ToolRequiringPermission::name().into(),
 502                is_error: false,
 503                content: "Allowed".into(),
 504                output: Some("Allowed".into())
 505            }),
 506            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 507                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 508                tool_name: ToolRequiringPermission::name().into(),
 509                is_error: true,
 510                content: "Permission to run tool denied by user".into(),
 511                output: Some("Permission to run tool denied by user".into())
 512            })
 513        ]
 514    );
 515
 516    // Simulate yet another tool call.
 517    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 518        LanguageModelToolUse {
 519            id: "tool_id_3".into(),
 520            name: ToolRequiringPermission::name().into(),
 521            raw_input: "{}".into(),
 522            input: json!({}),
 523            is_input_complete: true,
 524        },
 525    ));
 526    fake_model.end_last_completion_stream();
 527
 528    // Respond by always allowing tools.
 529    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 530    tool_call_auth_3
 531        .response
 532        .send(tool_call_auth_3.options[0].id.clone())
 533        .unwrap();
 534    cx.run_until_parked();
 535    let completion = fake_model.pending_completions().pop().unwrap();
 536    let message = completion.messages.last().unwrap();
 537    assert_eq!(
 538        message.content,
 539        vec![language_model::MessageContent::ToolResult(
 540            LanguageModelToolResult {
 541                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 542                tool_name: ToolRequiringPermission::name().into(),
 543                is_error: false,
 544                content: "Allowed".into(),
 545                output: Some("Allowed".into())
 546            }
 547        )]
 548    );
 549
 550    // Simulate a final tool call, ensuring we don't trigger authorization.
 551    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 552        LanguageModelToolUse {
 553            id: "tool_id_4".into(),
 554            name: ToolRequiringPermission::name().into(),
 555            raw_input: "{}".into(),
 556            input: json!({}),
 557            is_input_complete: true,
 558        },
 559    ));
 560    fake_model.end_last_completion_stream();
 561    cx.run_until_parked();
 562    let completion = fake_model.pending_completions().pop().unwrap();
 563    let message = completion.messages.last().unwrap();
 564    assert_eq!(
 565        message.content,
 566        vec![language_model::MessageContent::ToolResult(
 567            LanguageModelToolResult {
 568                tool_use_id: "tool_id_4".into(),
 569                tool_name: ToolRequiringPermission::name().into(),
 570                is_error: false,
 571                content: "Allowed".into(),
 572                output: Some("Allowed".into())
 573            }
 574        )]
 575    );
 576}
 577
 578#[gpui::test]
 579async fn test_tool_hallucination(cx: &mut TestAppContext) {
 580    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 581    let fake_model = model.as_fake();
 582
 583    let mut events = thread
 584        .update(cx, |thread, cx| {
 585            thread.send(UserMessageId::new(), ["abc"], cx)
 586        })
 587        .unwrap();
 588    cx.run_until_parked();
 589    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 590        LanguageModelToolUse {
 591            id: "tool_id_1".into(),
 592            name: "nonexistent_tool".into(),
 593            raw_input: "{}".into(),
 594            input: json!({}),
 595            is_input_complete: true,
 596        },
 597    ));
 598    fake_model.end_last_completion_stream();
 599
 600    let tool_call = expect_tool_call(&mut events).await;
 601    assert_eq!(tool_call.title, "nonexistent_tool");
 602    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 603    let update = expect_tool_call_update_fields(&mut events).await;
 604    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 605}
 606
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, `resume` must re-send the
    // conversation with an extra "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool ran; the follow-up request carries the tool result (cached, as
    // it is the newest message). messages[0] is the system prompt.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should end with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
 715
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After ToolUseLimitReached, sending a *new* user message (rather than
    // resuming) should continue the conversation with that message appended.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Hit the limit in the same turn as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should end with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new message is last and cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 789
 790async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 791    let event = events
 792        .next()
 793        .await
 794        .expect("no tool call authorization event received")
 795        .unwrap();
 796    match event {
 797        ThreadEvent::ToolCall(tool_call) => tool_call,
 798        event => {
 799            panic!("Unexpected event {event:?}");
 800        }
 801    }
 802}
 803
 804async fn expect_tool_call_update_fields(
 805    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 806) -> acp::ToolCallUpdate {
 807    let event = events
 808        .next()
 809        .await
 810        .expect("no tool call authorization event received")
 811        .unwrap();
 812    match event {
 813        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 814        event => {
 815            panic!("Unexpected event {event:?}");
 816        }
 817    }
 818}
 819
 820async fn next_tool_call_authorization(
 821    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 822) -> ToolCallAuthorization {
 823    loop {
 824        let event = events
 825            .next()
 826            .await
 827            .expect("no tool call authorization event received")
 828            .unwrap();
 829        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 830            let permission_kinds = tool_call_authorization
 831                .options
 832                .iter()
 833                .map(|o| o.kind)
 834                .collect::<Vec<_>>();
 835            assert_eq!(
 836                permission_kinds,
 837                vec![
 838                    acp::PermissionOptionKind::AllowAlways,
 839                    acp::PermissionOptionKind::AllowOnce,
 840                    acp::PermissionOptionKind::RejectOnce,
 841                ]
 842            );
 843            return tool_call_authorization;
 844        }
 845    }
 846}
 847
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end (real model) test: two delay-tool invocations in one turn
    // must both complete, and the model must describe their outputs.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // The final assistant message should mention the tools' "Ding" output.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
 892
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which of the thread's
    // registered tools are actually offered to the language model.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register three tools; each profile below enables only a subset.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the completion request captured by the fake model: only the
    // tools enabled in "test-1" should be present (sorted by name).
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
 972
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies that tools provided by an MCP context server are exposed to
    // the model, that their calls are routed to the server, and that a name
    // collision with a native tool is resolved by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Start a fake MCP server exposing a single "echo" tool; the returned
    // stream yields (request params, response sender) pairs for each call.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    // No native tool has been added yet, so the MCP tool keeps its bare name.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call should reach the MCP server with the model-supplied arguments.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` captured before the
    // tool call, not on a freshly popped pending completion — presumably it
    // was meant to check the follow-up request's tools; confirm intent.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The native tool keeps "echo"; the MCP tool is prefixed with its server name.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the prefixed call should be routed to the MCP server; note the
    // server still sees its original, unprefixed tool name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1135
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how tool names are disambiguated and truncated when multiple
    // context servers expose conflicting and/or overly long tool names
    // (bounded by MAX_TOOL_NAME_LENGTH).
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names at the truncation boundary; "aaa…" also conflicts
            // with an identically named tool on server "zzz" below.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Exceeds MAX_TOOL_NAME_LENGTH by one character.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected resolution (per the assertion below): conflicting short names
    // get a server prefix ("xxx_echo", "yyy_echo"); conflicting long names
    // get a truncated single-letter server prefix ("y_aaa…", "z_aaa…");
    // names at/over the limit are kept or truncated to fit.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1301
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // E2E test: cancelling a turn must close the event stream even while a
    // tool (InfiniteTool) is still running, and the thread must remain usable.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the requested order.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track when the echo call (specifically) reports completion.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1387
1388#[gpui::test]
1389async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1390    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1391    let fake_model = model.as_fake();
1392
1393    let events_1 = thread
1394        .update(cx, |thread, cx| {
1395            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1396        })
1397        .unwrap();
1398    cx.run_until_parked();
1399    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1400    cx.run_until_parked();
1401
1402    let events_2 = thread
1403        .update(cx, |thread, cx| {
1404            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1405        })
1406        .unwrap();
1407    cx.run_until_parked();
1408    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1409    fake_model
1410        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1411    fake_model.end_last_completion_stream();
1412
1413    let events_1 = events_1.collect::<Vec<_>>().await;
1414    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1415    let events_2 = events_2.collect::<Vec<_>>().await;
1416    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1417}
1418
1419#[gpui::test]
1420async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1421    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1422    let fake_model = model.as_fake();
1423
1424    let events_1 = thread
1425        .update(cx, |thread, cx| {
1426            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1427        })
1428        .unwrap();
1429    cx.run_until_parked();
1430    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1431    fake_model
1432        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1433    fake_model.end_last_completion_stream();
1434    let events_1 = events_1.collect::<Vec<_>>().await;
1435
1436    let events_2 = thread
1437        .update(cx, |thread, cx| {
1438            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1439        })
1440        .unwrap();
1441    cx.run_until_parked();
1442    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1443    fake_model
1444        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1445    fake_model.end_last_completion_stream();
1446    let events_2 = events_2.collect::<Vec<_>>().await;
1447
1448    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1449    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1450}
1451
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // Verifies that a Refusal stop reason rolls the thread back: everything
    // from the last user message onward is removed.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any model output, the transcript holds just the user message.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply; it appears in the transcript.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The whole turn — including the triggering user message — is gone.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1500
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Verifies that truncating at the first user message empties the thread
    // and resets token usage, and that the thread is still usable afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model sends a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the streamed UsageUpdate (input + output tokens).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first (and only) user message clears everything.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1616
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Verifies that truncating at the second user message restores the
    // thread (transcript and token usage) to its state after the first turn.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot of the expected state after the first turn; asserted both
    // before the second send and again after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second user message; the first turn must survive
    // intact, including its token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1725
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // Verifies that a separate summarization model generates the thread
    // title after the first turn, and only after the first turn.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // The summarization model is distinct from the chat model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title is used.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Only the first line of its output becomes the title ("Hello world").
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new request, and the title is unchanged.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1771
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // Verifies that when a completion request is built while one tool call is
    // still awaiting permission, that pending tool use is omitted from the
    // request, while the already-completed tool use and its result are kept.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending (permission never granted in this test).
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool call runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system prompt; only the conversation tail is asserted.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1846
/// End-to-end exercise of `NativeAgentConnection`: creating a thread,
/// listing and selecting models, prompting, cancellation, and verifying that
/// dropping the ACP thread tears down the underlying session.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Drive a prompt through the connection and stream back a response chunk.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1981
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the sequence of ACP tool-call events emitted while a tool's
    // input streams in and the tool then runs to completion:
    // Pending -> input update -> InProgress -> content -> Completed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First event: the tool call is announced with empty input, still pending.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // Completed input arrives as a field update carrying the full raw_input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // The tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // The tool reports its content while running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Final update: completion status plus the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2088
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // Verifies that a completion that succeeds on the first attempt emits no
    // Retry events, even in Burn mode (which enables automatic retries).
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the Stop event, collecting any Retry events seen.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2132
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // Verifies that in Burn mode a ServerOverloaded error triggers a single
    // retry after the provider's retry_after delay, and that the resumed
    // completion is recorded as a separate assistant turn.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial response, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past retry_after so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the response.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2197
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // Verifies that when a completion errors after emitting a tool use, the
    // tool still runs to completion and its result is included in the retried
    // request sent back to the model.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    // Fail the stream right after the tool use with a retryable error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the new request must contain the completed tool
    // call and its result (messages[0] is the system prompt).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2273
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // Verifies that after MAX_RETRY_ATTEMPTS failed retries the thread stops
    // retrying and surfaces the final error to the event stream.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence MAX_RETRY_ATTEMPTS + 1
    // failures total), advancing the clock past retry_after each time.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a single error on the stream.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2328
2329/// Filters out the stop events for asserting against in tests
2330fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2331    result_events
2332        .into_iter()
2333        .filter_map(|event| match event.unwrap() {
2334            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2335            _ => None,
2336        })
2337        .collect()
2338}
2339
/// Bundle of entities produced by `setup` for driving a `Thread` in tests.
struct ThreadTest {
    // Model the thread sends completions to (fake or real, per `TestModel`).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Shared project context; tests mutate it to simulate rules changes.
    project_context: Entity<ProjectContext>,
    // Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the project and the watched settings file.
    fs: Arc<FakeFs>,
}
2347
/// Which language model a test runs against: the real Claude Sonnet 4 model
/// (requires provider authentication) or an in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
2352
2353impl TestModel {
2354    fn id(&self) -> LanguageModelId {
2355        match self {
2356            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2357            TestModel::Fake => unreachable!(),
2358        }
2359    }
2360}
2361
/// Builds the standard test fixture: a fake filesystem seeded with a settings
/// file enabling all test tools, a test project, the requested model (fake or
/// authenticated real), and a `Thread` wired to them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model tests perform actual network I/O, which requires parking.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose profile enables every tool used by tests.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // The real model needs a production client and HTTP stack.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fakes are constructed directly; real models are
    // looked up in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2463
/// Runs before tests (via `ctor`) and enables `env_logger` output, but only
/// when the caller opted in by setting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2471
/// Spawns a detached background task that mirrors every change to the settings
/// file on the (fake) filesystem into the global `SettingsStore`, emulating
/// production settings watching for tests.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                // Ignore update failures (e.g. the app context is gone).
                .ok();
            }
        }
    })
    .detach();
}
2494
2495fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2496    completion
2497        .tools
2498        .iter()
2499        .map(|tool| tool.name.clone())
2500        .collect()
2501}
2502
/// Registers a custom context (MCP) server named `name` in the project
/// settings and starts it against a fake transport that advertises `tools`.
///
/// Returns a receiver yielding each incoming tool call's params together with
/// a oneshot sender the test must use to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in settings so the store knows to manage it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport answering the MCP handshake and tool requests in-process.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake and tool listing complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
2581
/// Tests that rules toggled after thread creation are applied to the first message.
///
/// This test verifies the fix for https://github.com/zed-industries/zed/issues/39057
/// where rules toggled in the Rules Library after opening a new agent thread were not
/// being applied to the first message sent in that thread.
///
/// The test simulates:
/// 1. Creating a thread with an initial rule in the project context
/// 2. Updating the project context entity with new rules (simulating what
///    NativeAgent.maintain_project_context does when rules are toggled)
/// 3. Sending a message through the thread
/// 4. Verifying that the newly toggled rules appear in the system prompt
///
/// The fix ensures that threads see updated rules by updating the project_context
/// entity in place rather than creating a new entity that threads wouldn't reference.
#[gpui::test]
async fn test_rules_toggled_after_thread_creation_are_applied(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Step 1: the thread starts out with an initial rule already present.
    let initial_rule_content = "Initial rule before thread creation.";
    project_context.update(cx, |context, _cx| {
        context.user_rules.push(UserRulesContext {
            uuid: UserPromptId::new(),
            title: Some("Initial Rule".to_string()),
            contents: initial_rule_content.to_string(),
        });
        context.has_user_rules = true;
    });

    let rule_id = UserPromptId::new();
    let new_rule_content = "Always respond in uppercase.";

    // Step 2: replace the rules in place, simulating a toggle in the Rules
    // Library before any message has been sent.
    project_context.update(cx, |context, _cx| {
        context.user_rules.clear();
        context.user_rules.push(UserRulesContext {
            uuid: rule_id,
            title: Some("New Rule".to_string()),
            contents: new_rule_content.to_string(),
        });
        context.has_user_rules = true;
    });

    // Step 3: send the thread's first message.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["test message"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);

    // Step 4: the system prompt (messages[0]) must reflect the toggled rule.
    let pending_completion = pending_completions.pop().unwrap();
    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();

    assert!(
        system_prompt.contains(new_rule_content),
        "System prompt should contain the rule content that was toggled after thread creation"
    );
}
2650
2651/// Verifies the fix for issue #39057: entity replacement breaks rule updates.
2652///
2653/// This test will FAIL without the fix and PASS with the fix.
2654///
2655/// The buggy code in maintain_project_context creates a NEW entity:
2656///   `this.project_context = cx.new(|_| new_data);`
2657/// The fixed code updates the EXISTING entity in place:
2658///   `this.project_context.update(cx, |ctx, _| *ctx = new_data);`
2659///
2660/// This test simulates the bug by:
2661/// 1. Creating a thread (thread holds reference to project_context entity)
2662/// 2. Creating NEW project context data with rules (simulating rule toggle)
2663/// 3. Simulating the BUGGY behavior: creating a completely new entity
2664/// 4. Sending a message (thread still references the old entity)
2665/// 5. Asserting rules are present (FAILS with bug, PASSES with fix)
2666///
2667/// With the fix, maintain_project_context updates in place, so we simulate that here.
2668#[gpui::test]
2669async fn test_project_context_entity_updated_in_place(cx: &mut TestAppContext) {
2670    let ThreadTest {
2671        model,
2672        thread,
2673        project_context,
2674        ..
2675    } = setup(cx, TestModel::Fake).await;
2676    let fake_model = model.as_fake();
2677
2678    let rule_content = "Rule toggled after thread creation.";
2679
2680    // Simulate what maintain_project_context does: build new context data with updated rules
2681    let new_context_data = {
2682        let mut context = ProjectContext::default();
2683        context.user_rules.push(UserRulesContext {
2684            uuid: UserPromptId::new(),
2685            title: Some("Toggled Rule".to_string()),
2686            contents: rule_content.to_string(),
2687        });
2688        context.has_user_rules = true;
2689        context
2690    };
2691
2692    // THE FIX: Update the existing entity in place (not creating a new entity)
2693    // This is what the fixed maintain_project_context does.
2694    // The buggy version would do: this.project_context = cx.new(|_| new_context_data);
2695    // which would leave the thread with a stale reference.
2696    project_context.update(cx, |context, _cx| {
2697        *context = new_context_data;
2698    });
2699
2700    thread
2701        .update(cx, |thread, cx| {
2702            thread.send(UserMessageId::new(), ["test message"], cx)
2703        })
2704        .unwrap();
2705
2706    cx.run_until_parked();
2707
2708    let mut pending_completions = fake_model.pending_completions();
2709    assert_eq!(pending_completions.len(), 1);
2710
2711    let pending_completion = pending_completions.pop().unwrap();
2712    let system_message = &pending_completion.messages[0];
2713    let system_prompt = system_message.content[0].to_str().unwrap();
2714
2715    assert!(
2716        system_prompt.contains(rule_content),
2717        "Thread should see rules because entity was updated in place.\n\
2718         Without the fix, maintain_project_context would create a NEW entity,\n\
2719         leaving the thread with a reference to the old entity (which has no rules).\n\
2720         This test passes because we simulate the FIXED behavior (update in place)."
2721    );
2722}