1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
76async fn test_thinking(cx: &mut TestAppContext) {
77 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
78 let fake_model = model.as_fake();
79
80 let events = thread
81 .update(cx, |thread, cx| {
82 thread.send(
83 UserMessageId::new(),
84 [indoc! {"
85 Testing:
86
87 Generate a thinking step where you just think the word 'Think',
88 and have your final answer be 'Hello'
89 "}],
90 cx,
91 )
92 })
93 .unwrap();
94 cx.run_until_parked();
95 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
96 text: "Think".to_string(),
97 signature: None,
98 });
99 fake_model.send_last_completion_stream_text_chunk("Hello");
100 fake_model
101 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
102 fake_model.end_last_completion_stream();
103
104 let events = events.collect().await;
105 thread.update(cx, |thread, _cx| {
106 assert_eq!(
107 thread.last_message().unwrap().to_markdown(),
108 indoc! {"
109 ## Assistant
110
111 <think>Think</think>
112 Hello
113 "}
114 )
115 });
116 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
117}
118
119#[gpui::test]
120async fn test_system_prompt(cx: &mut TestAppContext) {
121 let ThreadTest {
122 model,
123 thread,
124 project_context,
125 ..
126 } = setup(cx, TestModel::Fake).await;
127 let fake_model = model.as_fake();
128
129 project_context.update(cx, |project_context, _cx| {
130 project_context.shell = "test-shell".into()
131 });
132 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
133 thread
134 .update(cx, |thread, cx| {
135 thread.send(UserMessageId::new(), ["abc"], cx)
136 })
137 .unwrap();
138 cx.run_until_parked();
139 let mut pending_completions = fake_model.pending_completions();
140 assert_eq!(
141 pending_completions.len(),
142 1,
143 "unexpected pending completions: {:?}",
144 pending_completions
145 );
146
147 let pending_completion = pending_completions.pop().unwrap();
148 assert_eq!(pending_completion.messages[0].role, Role::System);
149
150 let system_message = &pending_completion.messages[0];
151 let system_prompt = system_message.content[0].to_str().unwrap();
152 assert!(
153 system_prompt.contains("test-shell"),
154 "unexpected system message: {:?}",
155 system_message
156 );
157 assert!(
158 system_prompt.contains("## Fixing Diagnostics"),
159 "unexpected system message: {:?}",
160 system_message
161 );
162}
163
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching anchor policy: only the final message of
    // each request carries `cache: true`; everything earlier is uncached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so assertions compare from index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Running the echo tool triggers a follow-up request whose final message
    // is the tool result; that message becomes the new cache anchor.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
297
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model; only runs with the "e2e" feature).
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should echo the delay tool's output (presumably "Ding" —
    // see DelayTool in test_tools) in its final message.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
357
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test: verifies that tool-call input streams incrementally,
    // i.e. the thread history exposes a tool use before its input is complete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to call a tool whose input is large enough to stream in
    // multiple chunks.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may be incomplete: the later
                        // keys (e.g. "g") haven't streamed in yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
409
410#[gpui::test]
411async fn test_tool_authorization(cx: &mut TestAppContext) {
412 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
413 let fake_model = model.as_fake();
414
415 let mut events = thread
416 .update(cx, |thread, cx| {
417 thread.add_tool(ToolRequiringPermission);
418 thread.send(UserMessageId::new(), ["abc"], cx)
419 })
420 .unwrap();
421 cx.run_until_parked();
422 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
423 LanguageModelToolUse {
424 id: "tool_id_1".into(),
425 name: ToolRequiringPermission::name().into(),
426 raw_input: "{}".into(),
427 input: json!({}),
428 is_input_complete: true,
429 },
430 ));
431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
432 LanguageModelToolUse {
433 id: "tool_id_2".into(),
434 name: ToolRequiringPermission::name().into(),
435 raw_input: "{}".into(),
436 input: json!({}),
437 is_input_complete: true,
438 },
439 ));
440 fake_model.end_last_completion_stream();
441 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
442 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
443
444 // Approve the first
445 tool_call_auth_1
446 .response
447 .send(tool_call_auth_1.options[1].id.clone())
448 .unwrap();
449 cx.run_until_parked();
450
451 // Reject the second
452 tool_call_auth_2
453 .response
454 .send(tool_call_auth_1.options[2].id.clone())
455 .unwrap();
456 cx.run_until_parked();
457
458 let completion = fake_model.pending_completions().pop().unwrap();
459 let message = completion.messages.last().unwrap();
460 assert_eq!(
461 message.content,
462 vec![
463 language_model::MessageContent::ToolResult(LanguageModelToolResult {
464 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
465 tool_name: ToolRequiringPermission::name().into(),
466 is_error: false,
467 content: "Allowed".into(),
468 output: Some("Allowed".into())
469 }),
470 language_model::MessageContent::ToolResult(LanguageModelToolResult {
471 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
472 tool_name: ToolRequiringPermission::name().into(),
473 is_error: true,
474 content: "Permission to run tool denied by user".into(),
475 output: Some("Permission to run tool denied by user".into())
476 })
477 ]
478 );
479
480 // Simulate yet another tool call.
481 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
482 LanguageModelToolUse {
483 id: "tool_id_3".into(),
484 name: ToolRequiringPermission::name().into(),
485 raw_input: "{}".into(),
486 input: json!({}),
487 is_input_complete: true,
488 },
489 ));
490 fake_model.end_last_completion_stream();
491
492 // Respond by always allowing tools.
493 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
494 tool_call_auth_3
495 .response
496 .send(tool_call_auth_3.options[0].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499 let completion = fake_model.pending_completions().pop().unwrap();
500 let message = completion.messages.last().unwrap();
501 assert_eq!(
502 message.content,
503 vec![language_model::MessageContent::ToolResult(
504 LanguageModelToolResult {
505 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
506 tool_name: ToolRequiringPermission::name().into(),
507 is_error: false,
508 content: "Allowed".into(),
509 output: Some("Allowed".into())
510 }
511 )]
512 );
513
514 // Simulate a final tool call, ensuring we don't trigger authorization.
515 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
516 LanguageModelToolUse {
517 id: "tool_id_4".into(),
518 name: ToolRequiringPermission::name().into(),
519 raw_input: "{}".into(),
520 input: json!({}),
521 is_input_complete: true,
522 },
523 ));
524 fake_model.end_last_completion_stream();
525 cx.run_until_parked();
526 let completion = fake_model.pending_completions().pop().unwrap();
527 let message = completion.messages.last().unwrap();
528 assert_eq!(
529 message.content,
530 vec![language_model::MessageContent::ToolResult(
531 LanguageModelToolResult {
532 tool_use_id: "tool_id_4".into(),
533 tool_name: ToolRequiringPermission::name().into(),
534 is_error: false,
535 content: "Allowed".into(),
536 output: Some("Allowed".into())
537 }
538 )]
539 );
540}
541
542#[gpui::test]
543async fn test_tool_hallucination(cx: &mut TestAppContext) {
544 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
545 let fake_model = model.as_fake();
546
547 let mut events = thread
548 .update(cx, |thread, cx| {
549 thread.send(UserMessageId::new(), ["abc"], cx)
550 })
551 .unwrap();
552 cx.run_until_parked();
553 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
554 LanguageModelToolUse {
555 id: "tool_id_1".into(),
556 name: "nonexistent_tool".into(),
557 raw_input: "{}".into(),
558 input: json!({}),
559 is_input_complete: true,
560 },
561 ));
562 fake_model.end_last_completion_stream();
563
564 let tool_call = expect_tool_call(&mut events).await;
565 assert_eq!(tool_call.title, "nonexistent_tool");
566 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
567 let update = expect_tool_call_update_fields(&mut events).await;
568 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
569}
570
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // Simulates hitting the server-side tool-use limit mid-turn, then resuming:
    // resume() appends a synthetic "Continue where you left off" user message,
    // which becomes the new cache anchor.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // After the tool runs, the follow-up request ends with the tool result,
    // which carries the cache anchor (messages[0] is the system prompt).
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The limit surfaces as the stream's final event: a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resume the thread; the request gains a trailing synthetic user message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
679
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the tool-use limit is reached, sending a *new* user message (rather
    // than resuming) should still include the prior tool use/result in the next
    // request, with the fresh user message as the cache anchor.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use immediately followed by the limit-reached status.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The limit surfaces as the stream's final event: a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; compare from index 1 onwards.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
753
754async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
755 let event = events
756 .next()
757 .await
758 .expect("no tool call authorization event received")
759 .unwrap();
760 match event {
761 ThreadEvent::ToolCall(tool_call) => tool_call,
762 event => {
763 panic!("Unexpected event {event:?}");
764 }
765 }
766}
767
768async fn expect_tool_call_update_fields(
769 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
770) -> acp::ToolCallUpdate {
771 let event = events
772 .next()
773 .await
774 .expect("no tool call authorization event received")
775 .unwrap();
776 match event {
777 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
778 event => {
779 panic!("Unexpected event {event:?}");
780 }
781 }
782}
783
784async fn next_tool_call_authorization(
785 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
786) -> ToolCallAuthorization {
787 loop {
788 let event = events
789 .next()
790 .await
791 .expect("no tool call authorization event received")
792 .unwrap();
793 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
794 let permission_kinds = tool_call_authorization
795 .options
796 .iter()
797 .map(|o| o.kind)
798 .collect::<Vec<_>>();
799 assert_eq!(
800 permission_kinds,
801 vec![
802 acp::PermissionOptionKind::AllowAlways,
803 acp::PermissionOptionKind::AllowOnce,
804 acp::PermissionOptionKind::RejectOnce,
805 ]
806 );
807 return tool_call_authorization;
808 }
809 }
810}
811
812#[gpui::test]
813#[cfg_attr(not(feature = "e2e"), ignore)]
814async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
815 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
816
817 // Test concurrent tool calls with different delay times
818 let events = thread
819 .update(cx, |thread, cx| {
820 thread.add_tool(DelayTool);
821 thread.send(
822 UserMessageId::new(),
823 [
824 "Call the delay tool twice in the same message.",
825 "Once with 100ms. Once with 300ms.",
826 "When both timers are complete, describe the outputs.",
827 ],
828 cx,
829 )
830 })
831 .unwrap()
832 .collect()
833 .await;
834
835 let stop_reasons = stop_events(events);
836 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
837
838 thread.update(cx, |thread, _cx| {
839 let last_message = thread.last_message().unwrap();
840 let agent_message = last_message.as_agent_message().unwrap();
841 let text = agent_message
842 .content
843 .iter()
844 .filter_map(|content| {
845 if let AgentMessageContent::Text(text) = content {
846 Some(text.as_str())
847 } else {
848 None
849 }
850 })
851 .collect::<String>();
852
853 assert!(text.contains("Ding"));
854 });
855}
856
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that agent profiles (configured in the settings file) control
    // which registered tools are exposed to the model per request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
936
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies that MCP (context server) tools are exposed to the model, that
    // their calls round-trip through the server, and that a name collision
    // with a native tool is resolved by prefixing the server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing an "echo" tool (same name as the native
    // EchoTool, which is not yet added to the thread).
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call reaches the MCP server with its arguments intact.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): `completion` here is still the request captured before the
    // tool result — presumably this was meant to re-pop the new pending
    // completion; confirm whether this assertion checks the intended request.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    // The MCP tool is renamed "test_server_echo"; the native one keeps "echo".
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the call under its original (unprefixed) name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1099
// Verifies how native and MCP (context-server) tool names are disambiguated and
// truncated before being offered to the model: names that conflict across
// sources get a server-derived prefix, and names at or over
// MAX_TOOL_NAME_LENGTH are shortened to fit.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile above and register all the native tools.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Name just under the limit; also defined by server "zzz" below, so
            // the full server-name prefix cannot fit within MAX_TOOL_NAME_LENGTH.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Name over the limit — expected to be truncated to fit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Inspect the tool list actually offered to the model. Conflicting "echo"
    // MCP tools are prefixed with their server id ("xxx_echo"/"yyy_echo");
    // conflicting long names get a single-letter server prefix ("y_…"/"z_…")
    // so they still fit — NOTE(review): exact prefixing scheme inferred from
    // these expectations; confirm against the production truncation code.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1265
// End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
// cancelling a turn must close the event stream even while a tool is still
// running, and the thread must accept new messages afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the prompted order; remember the echo
            // call's id so we can watch for its completion below.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Only a Completed status update for the echo call counts; other
            // updates (and the still-running infinite tool) are ignored.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1350
1351#[gpui::test]
1352#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1353async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1354 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1355 let fake_model = model.as_fake();
1356
1357 let events_1 = thread
1358 .update(cx, |thread, cx| {
1359 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1360 })
1361 .unwrap();
1362 cx.run_until_parked();
1363 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1364 cx.run_until_parked();
1365
1366 let events_2 = thread
1367 .update(cx, |thread, cx| {
1368 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1369 })
1370 .unwrap();
1371 cx.run_until_parked();
1372 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1373 fake_model
1374 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1375 fake_model.end_last_completion_stream();
1376
1377 let events_1 = events_1.collect::<Vec<_>>().await;
1378 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1379 let events_2 = events_2.collect::<Vec<_>>().await;
1380 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1381}
1382
1383#[gpui::test]
1384async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1385 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1386 let fake_model = model.as_fake();
1387
1388 let events_1 = thread
1389 .update(cx, |thread, cx| {
1390 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1391 })
1392 .unwrap();
1393 cx.run_until_parked();
1394 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1395 fake_model
1396 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1397 fake_model.end_last_completion_stream();
1398 let events_1 = events_1.collect::<Vec<_>>().await;
1399
1400 let events_2 = thread
1401 .update(cx, |thread, cx| {
1402 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1403 })
1404 .unwrap();
1405 cx.run_until_parked();
1406 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1407 fake_model
1408 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1409 fake_model.end_last_completion_stream();
1410 let events_2 = events_2.collect::<Vec<_>>().await;
1411
1412 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1413 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1414}
1415
1416#[gpui::test]
1417async fn test_refusal(cx: &mut TestAppContext) {
1418 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1419 let fake_model = model.as_fake();
1420
1421 let events = thread
1422 .update(cx, |thread, cx| {
1423 thread.send(UserMessageId::new(), ["Hello"], cx)
1424 })
1425 .unwrap();
1426 cx.run_until_parked();
1427 thread.read_with(cx, |thread, _| {
1428 assert_eq!(
1429 thread.to_markdown(),
1430 indoc! {"
1431 ## User
1432
1433 Hello
1434 "}
1435 );
1436 });
1437
1438 fake_model.send_last_completion_stream_text_chunk("Hey!");
1439 cx.run_until_parked();
1440 thread.read_with(cx, |thread, _| {
1441 assert_eq!(
1442 thread.to_markdown(),
1443 indoc! {"
1444 ## User
1445
1446 Hello
1447
1448 ## Assistant
1449
1450 Hey!
1451 "}
1452 );
1453 });
1454
1455 // If the model refuses to continue, the thread should remove all the messages after the last user message.
1456 fake_model
1457 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1458 let events = events.collect::<Vec<_>>().await;
1459 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1460 thread.read_with(cx, |thread, _| {
1461 assert_eq!(thread.to_markdown(), "");
1462 });
1463}
1464
// Truncating at the very first user message must empty the thread, reset the
// token-usage accounting, and still allow sending fresh messages afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate back to this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has reported usage yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply along with a usage update so latest_token_usage has data.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects input + output tokens from the update above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: thread content and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage now reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1580
// Truncating at the second user message must restore the thread — content and
// token usage — to exactly the state it had after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: one user message and one complete assistant reply with
    // a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable snapshot check: asserted both before the second exchange and
    // again after truncation, to prove truncation restored this exact state.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep the id so we can truncate back to it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        // Usage reflects the latest (second) completion.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and confirm we're back to the first
    // exchange's state.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1689
// A title is generated by the summarization model after the first exchange
// (using only the first line of its output); subsequent exchanges must not
// trigger another title generation.
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for summarization/title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Chunks are split mid-word on purpose; only the first line ("Hello world")
    // should become the title — the text after the newline is discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1736
// While one tool call is still awaiting user permission, building a completion
// request must omit that pending tool use/result entirely, while completed
// tool calls are included as a ToolUse plus a matching ToolResult.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This one stays pending: the permission prompt is never answered.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This one runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped — presumably the system/context message; verify
    // against build_completion_request if this assumption changes.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1811
// Integration test for the ACP agent connection: model listing/selection,
// thread creation, prompting, cancellation, and session teardown when the ACP
// thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // All HTTP traffic gets a 404 — nothing here should hit the network.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream back a reply through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with a session-lookup error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1944
// Verifies the exact sequence of ACP tool-call events emitted for a single
// tool use: Pending call → raw-input update (after input streaming finishes)
// → InProgress → streamed content → Completed with the tool's raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) The initial tool call arrives in Pending state with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2) Completed input streaming updates the raw input (title/kind resent).
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3) Execution begins.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5) The tool finishes, carrying its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2046
2047#[gpui::test]
2048async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2049 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2050 let fake_model = model.as_fake();
2051
2052 let mut events = thread
2053 .update(cx, |thread, cx| {
2054 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2055 thread.send(UserMessageId::new(), ["Hello!"], cx)
2056 })
2057 .unwrap();
2058 cx.run_until_parked();
2059
2060 fake_model.send_last_completion_stream_text_chunk("Hey!");
2061 fake_model.end_last_completion_stream();
2062
2063 let mut retry_events = Vec::new();
2064 while let Some(Ok(event)) = events.next().await {
2065 match event {
2066 ThreadEvent::Retry(retry_status) => {
2067 retry_events.push(retry_status);
2068 }
2069 ThreadEvent::Stop(..) => break,
2070 _ => {}
2071 }
2072 }
2073
2074 assert_eq!(retry_events.len(), 0);
2075 thread.read_with(cx, |thread, _cx| {
2076 assert_eq!(
2077 thread.to_markdown(),
2078 indoc! {"
2079 ## User
2080
2081 Hello!
2082
2083 ## Assistant
2084
2085 Hey!
2086 "}
2087 )
2088 });
2089}
2090
// A ServerOverloaded error mid-stream must trigger exactly one Retry event,
// wait out the provider's retry_after delay, resume the completion, and keep
// both the pre-error and post-retry assistant text in the thread.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode — NOTE(review): presumably required for automatic
            // retries; confirm against the completion-mode handling.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a reply, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past retry_after so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Collect Retry events until the turn stops.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // Both halves of the reply survive, separated by a [resume] marker.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2155
// If the stream errors after a tool use was received, the tool still runs to
// completion, and the retried request must include both the ToolUse and its
// ToolResult so the model can pick up where it left off.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Deliver a complete tool use, then fail the stream with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the next request must carry the tool use AND its
    // completed result (messages[0] — presumably the system message — is
    // skipped from the comparison).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2231
#[gpui::test]
// Verifies that after MAX_RETRY_ATTEMPTS consecutive retryable failures the
// thread emits one Retry event per attempt and then surfaces the underlying
// error instead of retrying forever.
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (MAX_RETRY_ATTEMPTS + 1 total
    // stream errors), advancing the clock past retry_after each time so the
    // next retry is scheduled and run.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, collecting Retry notifications and errors until
    // the thread stops.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS, and
    // exactly one terminal error that preserves the original failure kind.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2286
2287/// Filters out the stop events for asserting against in tests
2288fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2289 result_events
2290 .into_iter()
2291 .filter_map(|event| match event.unwrap() {
2292 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2293 _ => None,
2294 })
2295 .collect()
2296}
2297
/// Everything a test needs from `setup`: the model under test, the thread
/// driving it, and the supporting project/file-system state.
struct ThreadTest {
    // The language model the thread talks to (fake or real, per `TestModel`).
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    // Kept so tests can mutate files (e.g. rewrite the settings file).
    fs: Arc<FakeFs>,
}
2305
/// Which model `setup` should wire up: a real Claude Sonnet 4 (requires
/// authentication against the production client) or the fake in-process model.
enum TestModel {
    Sonnet4,
    Fake,
}
2310
2311impl TestModel {
2312 fn id(&self) -> LanguageModelId {
2313 match self {
2314 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2315 TestModel::Fake => unreachable!(),
2316 }
2317 }
2318}
2319
/// Builds the full test environment for a `Thread`: a fake file system with a
/// settings file enabling all test tools, initialized settings/model globals,
/// a test project, and the requested language model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model authentication can block; allow it.
    cx.executor().allow_parking();

    // Write a settings file enabling every test tool under a dedicated
    // profile, before settings are initialized below.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize the global state the thread depends on: settings, project,
    // agent settings, HTTP client, and the language-model registry.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep global settings in sync with later edits to the settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are constructed directly; real ones are
    // looked up in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2419
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Logging is opt-in: only install env_logger when RUST_LOG is set.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2427
2428fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2429 let fs = fs.clone();
2430 cx.spawn({
2431 async move |cx| {
2432 let mut new_settings_content_rx = settings::watch_config_file(
2433 cx.background_executor(),
2434 fs,
2435 paths::settings_file().clone(),
2436 );
2437
2438 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2439 cx.update(|cx| {
2440 SettingsStore::update_global(cx, |settings, cx| {
2441 settings.set_user_settings(&new_settings_content, cx)
2442 })
2443 })
2444 .ok();
2445 }
2446 }
2447 })
2448 .detach();
2449}
2450
2451fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2452 completion
2453 .tools
2454 .iter()
2455 .map(|tool| tool.name.clone())
2456 .collect()
2457}
2458
/// Registers and starts a fake MCP context server named `name` that advertises
/// `tools`. Returns a receiver yielding each incoming tool call together with
/// a oneshot sender the test uses to reply to it.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings first so the store treats it as
    // an enabled custom server when it starts below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    // Never actually executed — the fake transport intercepts
                    // all traffic.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport implementing the minimal MCP handshake: Initialize,
    // ListTools (returning `tools`), and CallTool (forwarded to the test).
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test sends a response through the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake (Initialize + ListTools) complete before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}