// mod.rs — agent thread tests

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, new_prompt_id};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use fs::{FakeFs, Fs};
  11use futures::{
  12    StreamExt,
  13    channel::{
  14        mpsc::{self, UnboundedReceiver},
  15        oneshot,
  16    },
  17};
  18use gpui::{
  19    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  20};
  21use indoc::indoc;
  22use language_model::{
  23    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  24    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  25    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  26    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  27};
  28use pretty_assertions::assert_eq;
  29use project::{
  30    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  31};
  32use prompt_store::ProjectContext;
  33use reqwest_client::ReqwestClient;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use serde_json::json;
  37use settings::{Settings, SettingsStore};
  38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  39use util::path;
  40
  41mod test_tools;
  42use test_tools::*;
  43
  44#[gpui::test]
  45async fn test_echo(cx: &mut TestAppContext) {
  46    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  47    let fake_model = model.as_fake();
  48
  49    let events = thread
  50        .update(cx, |thread, cx| {
  51            thread.send(new_prompt_id(), ["Testing: Reply with 'Hello'"], cx)
  52        })
  53        .unwrap();
  54    cx.run_until_parked();
  55    fake_model.send_last_completion_stream_text_chunk("Hello");
  56    fake_model
  57        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  58    fake_model.end_last_completion_stream();
  59
  60    let events = events.collect().await;
  61    thread.update(cx, |thread, _cx| {
  62        assert_eq!(
  63            thread.last_message().unwrap().to_markdown(),
  64            indoc! {"
  65                ## Assistant
  66
  67                Hello
  68            "}
  69        )
  70    });
  71    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  72}
  73
  74#[gpui::test]
  75async fn test_thinking(cx: &mut TestAppContext) {
  76    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  77    let fake_model = model.as_fake();
  78
  79    let events = thread
  80        .update(cx, |thread, cx| {
  81            thread.send(
  82                new_prompt_id(),
  83                [indoc! {"
  84                    Testing:
  85
  86                    Generate a thinking step where you just think the word 'Think',
  87                    and have your final answer be 'Hello'
  88                "}],
  89                cx,
  90            )
  91        })
  92        .unwrap();
  93    cx.run_until_parked();
  94    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
  95        text: "Think".to_string(),
  96        signature: None,
  97    });
  98    fake_model.send_last_completion_stream_text_chunk("Hello");
  99    fake_model
 100        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 101    fake_model.end_last_completion_stream();
 102
 103    let events = events.collect().await;
 104    thread.update(cx, |thread, _cx| {
 105        assert_eq!(
 106            thread.last_message().unwrap().to_markdown(),
 107            indoc! {"
 108                ## Assistant
 109
 110                <think>Think</think>
 111                Hello
 112            "}
 113        )
 114    });
 115    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 116}
 117
 118#[gpui::test]
 119async fn test_system_prompt(cx: &mut TestAppContext) {
 120    let ThreadTest {
 121        model,
 122        thread,
 123        project_context,
 124        ..
 125    } = setup(cx, TestModel::Fake).await;
 126    let fake_model = model.as_fake();
 127
 128    project_context.update(cx, |project_context, _cx| {
 129        project_context.shell = "test-shell".into()
 130    });
 131    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 132    thread
 133        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
 134        .unwrap();
 135    cx.run_until_parked();
 136    let mut pending_completions = fake_model.pending_completions();
 137    assert_eq!(
 138        pending_completions.len(),
 139        1,
 140        "unexpected pending completions: {:?}",
 141        pending_completions
 142    );
 143
 144    let pending_completion = pending_completions.pop().unwrap();
 145    assert_eq!(pending_completion.messages[0].role, Role::System);
 146
 147    let system_message = &pending_completion.messages[0];
 148    let system_prompt = system_message.content[0].to_str().unwrap();
 149    assert!(
 150        system_prompt.contains("test-shell"),
 151        "unexpected system message: {:?}",
 152        system_message
 153    );
 154    assert!(
 155        system_prompt.contains("## Fixing Diagnostics"),
 156        "unexpected system message: {:?}",
 157        system_message
 158    );
 159}
 160
/// Verifies the prompt-caching markers on outgoing requests: only the newest
/// message in each request (a user message, or the tool result that follows a
/// tool call) is flagged `cache: true`, so the provider can cache the longest
/// stable prefix of the conversation.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Earlier messages must now be `cache: false`; only "Message 2" is marked.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool result goes back to the model as a user message; being the
    // newest message, it is the only one marked `cache: true`.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 294
/// End-to-end test (ignored unless the `e2e` feature is enabled): runs against
/// a real model and exercises tool calls that finish both before and after the
/// model stops streaming its turn.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                new_prompt_id(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                new_prompt_id(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model was asked to echo the delay tool's output, so the final
    // assistant message should contain the tool's "Ding".
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 354
/// End-to-end test (ignored unless the `e2e` feature is enabled): verifies
/// that tool-call input streams incrementally. At least one `ToolCall` event
/// must be observed while the call's input JSON is still incomplete, and once
/// complete the input must contain all streamed fields.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(new_prompt_id(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While the call is pending, input may still be arriving:
                        // missing the "g" field means we caught a genuinely
                        // partial tool use mid-stream.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 406
 407#[gpui::test]
 408async fn test_tool_authorization(cx: &mut TestAppContext) {
 409    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 410    let fake_model = model.as_fake();
 411
 412    let mut events = thread
 413        .update(cx, |thread, cx| {
 414            thread.add_tool(ToolRequiringPermission);
 415            thread.send(new_prompt_id(), ["abc"], cx)
 416        })
 417        .unwrap();
 418    cx.run_until_parked();
 419    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 420        LanguageModelToolUse {
 421            id: "tool_id_1".into(),
 422            name: ToolRequiringPermission::name().into(),
 423            raw_input: "{}".into(),
 424            input: json!({}),
 425            is_input_complete: true,
 426        },
 427    ));
 428    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 429        LanguageModelToolUse {
 430            id: "tool_id_2".into(),
 431            name: ToolRequiringPermission::name().into(),
 432            raw_input: "{}".into(),
 433            input: json!({}),
 434            is_input_complete: true,
 435        },
 436    ));
 437    fake_model.end_last_completion_stream();
 438    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 439    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 440
 441    // Approve the first
 442    tool_call_auth_1
 443        .response
 444        .send(tool_call_auth_1.options[1].id.clone())
 445        .unwrap();
 446    cx.run_until_parked();
 447
 448    // Reject the second
 449    tool_call_auth_2
 450        .response
 451        .send(tool_call_auth_1.options[2].id.clone())
 452        .unwrap();
 453    cx.run_until_parked();
 454
 455    let completion = fake_model.pending_completions().pop().unwrap();
 456    let message = completion.messages.last().unwrap();
 457    assert_eq!(
 458        message.content,
 459        vec![
 460            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 461                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 462                tool_name: ToolRequiringPermission::name().into(),
 463                is_error: false,
 464                content: "Allowed".into(),
 465                output: Some("Allowed".into())
 466            }),
 467            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 468                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 469                tool_name: ToolRequiringPermission::name().into(),
 470                is_error: true,
 471                content: "Permission to run tool denied by user".into(),
 472                output: None
 473            })
 474        ]
 475    );
 476
 477    // Simulate yet another tool call.
 478    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 479        LanguageModelToolUse {
 480            id: "tool_id_3".into(),
 481            name: ToolRequiringPermission::name().into(),
 482            raw_input: "{}".into(),
 483            input: json!({}),
 484            is_input_complete: true,
 485        },
 486    ));
 487    fake_model.end_last_completion_stream();
 488
 489    // Respond by always allowing tools.
 490    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 491    tool_call_auth_3
 492        .response
 493        .send(tool_call_auth_3.options[0].id.clone())
 494        .unwrap();
 495    cx.run_until_parked();
 496    let completion = fake_model.pending_completions().pop().unwrap();
 497    let message = completion.messages.last().unwrap();
 498    assert_eq!(
 499        message.content,
 500        vec![language_model::MessageContent::ToolResult(
 501            LanguageModelToolResult {
 502                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 503                tool_name: ToolRequiringPermission::name().into(),
 504                is_error: false,
 505                content: "Allowed".into(),
 506                output: Some("Allowed".into())
 507            }
 508        )]
 509    );
 510
 511    // Simulate a final tool call, ensuring we don't trigger authorization.
 512    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 513        LanguageModelToolUse {
 514            id: "tool_id_4".into(),
 515            name: ToolRequiringPermission::name().into(),
 516            raw_input: "{}".into(),
 517            input: json!({}),
 518            is_input_complete: true,
 519        },
 520    ));
 521    fake_model.end_last_completion_stream();
 522    cx.run_until_parked();
 523    let completion = fake_model.pending_completions().pop().unwrap();
 524    let message = completion.messages.last().unwrap();
 525    assert_eq!(
 526        message.content,
 527        vec![language_model::MessageContent::ToolResult(
 528            LanguageModelToolResult {
 529                tool_use_id: "tool_id_4".into(),
 530                tool_name: ToolRequiringPermission::name().into(),
 531                is_error: false,
 532                content: "Allowed".into(),
 533                output: Some("Allowed".into())
 534            }
 535        )]
 536    );
 537}
 538
 539#[gpui::test]
 540async fn test_tool_hallucination(cx: &mut TestAppContext) {
 541    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 542    let fake_model = model.as_fake();
 543
 544    let mut events = thread
 545        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
 546        .unwrap();
 547    cx.run_until_parked();
 548    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 549        LanguageModelToolUse {
 550            id: "tool_id_1".into(),
 551            name: "nonexistent_tool".into(),
 552            raw_input: "{}".into(),
 553            input: json!({}),
 554            is_input_complete: true,
 555        },
 556    ));
 557    fake_model.end_last_completion_stream();
 558
 559    let tool_call = expect_tool_call(&mut events).await;
 560    assert_eq!(tool_call.title, "nonexistent_tool");
 561    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 562    let update = expect_tool_call_update_fields(&mut events).await;
 563    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 564}
 565
/// After the model reports `ToolUseLimitReached`, the turn's event stream ends
/// with a `ToolUseLimitReachedError`; calling `resume` then continues the
/// conversation by appending a "Continue where you left off" user message.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // messages[0] is the system prompt; the tool result (newest message) is
    // the only one flagged `cache: true`.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The limit surfaces as the final event of the stream, as an error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming sends the prior transcript plus a canned continuation prompt,
    // which becomes the newest (and therefore cached) message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
 674
/// After the model reports `ToolUseLimitReached`, sending a brand-new user
/// message (instead of resuming) must still work: the new prompt is appended
/// after the tool-use transcript and becomes the only cached message.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use and immediately hit the tool-use limit in the same turn.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The limit surfaces as the final event of the stream, as an error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new "ghi" prompt is the newest
    // message and therefore the only one flagged `cache: true`.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 748
 749async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 750    let event = events
 751        .next()
 752        .await
 753        .expect("no tool call authorization event received")
 754        .unwrap();
 755    match event {
 756        ThreadEvent::ToolCall(tool_call) => tool_call,
 757        event => {
 758            panic!("Unexpected event {event:?}");
 759        }
 760    }
 761}
 762
 763async fn expect_tool_call_update_fields(
 764    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 765) -> acp::ToolCallUpdate {
 766    let event = events
 767        .next()
 768        .await
 769        .expect("no tool call authorization event received")
 770        .unwrap();
 771    match event {
 772        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 773        event => {
 774            panic!("Unexpected event {event:?}");
 775        }
 776    }
 777}
 778
 779async fn next_tool_call_authorization(
 780    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 781) -> ToolCallAuthorization {
 782    loop {
 783        let event = events
 784            .next()
 785            .await
 786            .expect("no tool call authorization event received")
 787            .unwrap();
 788        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 789            let permission_kinds = tool_call_authorization
 790                .options
 791                .iter()
 792                .map(|o| o.kind)
 793                .collect::<Vec<_>>();
 794            assert_eq!(
 795                permission_kinds,
 796                vec![
 797                    acp::PermissionOptionKind::AllowAlways,
 798                    acp::PermissionOptionKind::AllowOnce,
 799                    acp::PermissionOptionKind::RejectOnce,
 800                ]
 801            );
 802            return tool_call_authorization;
 803        }
 804    }
 805}
 806
/// End-to-end test (ignored unless the `e2e` feature is enabled): the model is
/// asked to issue two delay-tool calls in one message; both must complete and
/// the final assistant text must describe their output.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                new_prompt_id(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // Concatenate the text portions of the final assistant message and check
    // that the delay tool's "Ding" output made it in.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
 851
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching the active agent profile changes which of the
    // registered tools are included in the completion request sent to the model.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools up front; the profile selects the visible subset.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(new_prompt_id(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the request the fake model received for this turn.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(new_prompt_id(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
 931
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // MCP (context server) tool integration: a server-provided "echo" tool is
    // exposed to the model and invoked; after a native tool with the same name
    // is added, the MCP tool is re-exposed under a server-prefixed name
    // ("test_server_echo") and results are attributed to the correct names.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Fake context server exposing a single "echo" tool; the returned stream
    // yields (params, response-sender) pairs for each tool invocation.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(new_prompt_id(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The invocation is forwarded to the context server; reply through the
    // response channel to complete the tool call.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` captured before the
    // tool ran; presumably the follow-up pending completion was meant to be
    // popped and checked here instead — confirm.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(new_prompt_id(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // The model calls both the prefixed MCP tool and the native tool.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the MCP call reaches the context server; note the server still
    // sees its original, un-prefixed tool name ("echo").
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1093
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Exercises tool-name disambiguation across multiple context servers:
    // names that collide with native tools or with each other, and names at
    // or over MAX_TOOL_NAME_LENGTH. The expected list at the bottom pins the
    // resulting renames.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit name, also exposed by server "zzz" below.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // At-limit name, also exposed by server "zzz" below.
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Over-limit unique name; must be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Go"], cx))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected behavior, per the list below: unique names pass through
    // unchanged; "echo" conflicts gain their server-id prefix; the colliding
    // near-limit "a…" names gain a single-letter server prefix; the over-limit
    // "c…" name is truncated.
    // NOTE(review): both "yyy" and "zzz" expose the same (MAX-1)-length "b…"
    // tool, yet only one "b…" entry appears here — presumably prefixing it
    // would exceed the length limit so one is dropped; confirm this is the
    // intended collision behavior.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1257
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // E2E (real model): cancelling a turn mid-flight must close the event
    // stream with StopReason::Cancelled — even while a tool is still running —
    // and leave the thread usable for a subsequent send.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                new_prompt_id(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the requested order.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo tool (the one that can finish;
            // the infinite tool never completes on its own).
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                new_prompt_id(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1342
1343#[gpui::test]
1344async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1345    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1346    let fake_model = model.as_fake();
1347
1348    let events_1 = thread
1349        .update(cx, |thread, cx| {
1350            thread.send(new_prompt_id(), ["Hello 1"], cx)
1351        })
1352        .unwrap();
1353    cx.run_until_parked();
1354    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1355    cx.run_until_parked();
1356
1357    let events_2 = thread
1358        .update(cx, |thread, cx| {
1359            thread.send(new_prompt_id(), ["Hello 2"], cx)
1360        })
1361        .unwrap();
1362    cx.run_until_parked();
1363    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1364    fake_model
1365        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1366    fake_model.end_last_completion_stream();
1367
1368    let events_1 = events_1.collect::<Vec<_>>().await;
1369    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1370    let events_2 = events_2.collect::<Vec<_>>().await;
1371    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1372}
1373
1374#[gpui::test]
1375async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1376    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1377    let fake_model = model.as_fake();
1378
1379    let events_1 = thread
1380        .update(cx, |thread, cx| {
1381            thread.send(new_prompt_id(), ["Hello 1"], cx)
1382        })
1383        .unwrap();
1384    cx.run_until_parked();
1385    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1386    fake_model
1387        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1388    fake_model.end_last_completion_stream();
1389    let events_1 = events_1.collect::<Vec<_>>().await;
1390
1391    let events_2 = thread
1392        .update(cx, |thread, cx| {
1393            thread.send(new_prompt_id(), ["Hello 2"], cx)
1394        })
1395        .unwrap();
1396    cx.run_until_parked();
1397    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1398    fake_model
1399        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1400    fake_model.end_last_completion_stream();
1401    let events_2 = events_2.collect::<Vec<_>>().await;
1402
1403    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1404    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1405}
1406
1407#[gpui::test]
1408async fn test_refusal(cx: &mut TestAppContext) {
1409    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1410    let fake_model = model.as_fake();
1411
1412    let events = thread
1413        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
1414        .unwrap();
1415    cx.run_until_parked();
1416    thread.read_with(cx, |thread, _| {
1417        assert_eq!(
1418            thread.to_markdown(),
1419            indoc! {"
1420                ## User
1421
1422                Hello
1423            "}
1424        );
1425    });
1426
1427    fake_model.send_last_completion_stream_text_chunk("Hey!");
1428    cx.run_until_parked();
1429    thread.read_with(cx, |thread, _| {
1430        assert_eq!(
1431            thread.to_markdown(),
1432            indoc! {"
1433                ## User
1434
1435                Hello
1436
1437                ## Assistant
1438
1439                Hey!
1440            "}
1441        );
1442    });
1443
1444    // If the model refuses to continue, the thread should remove all the messages after the last user message.
1445    fake_model
1446        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1447    let events = events.collect::<Vec<_>>().await;
1448    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1449    thread.read_with(cx, |thread, _| {
1450        assert_eq!(thread.to_markdown(), "");
1451    });
1452}
1453
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Rewinding to the first message should clear the whole thread (both the
    // markdown transcript and the token usage), after which new sends work
    // normally.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = new_prompt_id();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model streams a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Reported usage is the sum of input and output tokens.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Rewinding to the first message removes it and everything after it.
    thread
        .update(cx, |thread, cx| thread.rewind(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hi"], cx))
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-rewind turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1567
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Rewinding to the second message should restore the thread to the exact
    // state it had after the first turn, including the first turn's token
    // usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion for the post-first-turn state; used both before the
    // second turn and again after rewinding it away.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Run a full second turn with different (40k/20k) usage numbers.
    let second_message_id = new_prompt_id();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Rewind to the second message: the thread should match the first-turn
    // snapshot again, including the 32k/16k usage.
    thread
        .update(cx, |thread, cx| thread.rewind(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1676
#[gpui::test]
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model generates the thread title after the
    // first turn; only the first line of its output becomes the title, and
    // later turns don't regenerate it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Distinct fake model for summarization, so its completions can be driven
    // and asserted independently of the main model's.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Everything after the first newline ("Goodnight Moon") is discarded.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(new_prompt_id(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request for this turn.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1721
// Tool calls that are still pending (here, one blocked behind a permission
// prompt) must be omitted when the thread builds a completion request, while
// already-resolved tool calls appear as a tool-use/tool-result pair.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This call will stay pending on its permission request...
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // ...while this one can run to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped here (presumably the system prompt — confirm
    // against Thread::build_completion_request). Only the echo tool use and
    // its result should appear; the permission tool is still pending.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1796
// End-to-end exercise of `NativeAgentConnection`: model listing and
// selection, thread creation, prompting, cancellation, and session cleanup
// when the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream back a response, then check the rendered
    // transcript.
    let request = acp_thread.update(cx, |thread, cx| {
        thread.send(new_prompt_id(), vec!["abc".into()], cx)
    });
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                acp::PromptRequest {
                    prompt_id: Some(new_prompt_id()),
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1931
// A tool call should stream through its full lifecycle of ACP events:
// initial `ToolCall`, a raw-input update once input streaming finishes, an
// InProgress status, streamed content, and a final Completed status with
// the raw output attached.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Think"], cx))
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Stage 1: the partial tool use surfaces as a pending ToolCall with
    // empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // Stage 2: completed input streaming updates the raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // Stage 3: the tool begins executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // Stage 4: the tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // Stage 5: the tool finishes, carrying its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2031
2032#[gpui::test]
2033async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2034    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2035    let fake_model = model.as_fake();
2036
2037    let mut events = thread
2038        .update(cx, |thread, cx| {
2039            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2040            thread.send(new_prompt_id(), ["Hello!"], cx)
2041        })
2042        .unwrap();
2043    cx.run_until_parked();
2044
2045    fake_model.send_last_completion_stream_text_chunk("Hey!");
2046    fake_model.end_last_completion_stream();
2047
2048    let mut retry_events = Vec::new();
2049    while let Some(Ok(event)) = events.next().await {
2050        match event {
2051            ThreadEvent::Retry(retry_status) => {
2052                retry_events.push(retry_status);
2053            }
2054            ThreadEvent::Stop(..) => break,
2055            _ => {}
2056        }
2057    }
2058
2059    assert_eq!(retry_events.len(), 0);
2060    thread.read_with(cx, |thread, _cx| {
2061        assert_eq!(
2062            thread.to_markdown(),
2063            indoc! {"
2064                ## User
2065
2066                Hello!
2067
2068                ## Assistant
2069
2070                Hey!
2071            "}
2072        )
2073    });
2074}
2075
// A retryable completion error (server overloaded) should trigger exactly
// one Retry event, wait out the provider-specified delay, and resume the
// turn in a fresh assistant message.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(new_prompt_id(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial response, then fail with a retryable error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry delay so the thread re-sends the request.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the truncated message and marks the resume point.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2140
// When a retryable error arrives after a tool use has been streamed, the
// tool call should still be executed, and the retried request must include
// both the tool use and its result.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(new_prompt_id(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a complete tool use, then fail the turn with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the re-sent request must contain the tool use
    // and its (already computed) result. messages[0] is skipped here —
    // presumably the system prompt.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Finish the retried turn and check the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2216
// After MAX_RETRY_ATTEMPTS consecutive retryable failures, the thread must
// give up: one Retry event per attempt, then the final error is surfaced to
// the event stream.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(new_prompt_id(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt: the initial one plus each retry.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported exactly once, as the original error type.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2271
2272/// Filters out the stop events for asserting against in tests
2273fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2274    result_events
2275        .into_iter()
2276        .filter_map(|event| match event.unwrap() {
2277            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2278            _ => None,
2279        })
2280        .collect()
2281}
2282
/// Bundle of entities produced by [`setup`] for driving a `Thread` in tests.
struct ThreadTest {
    // The active language model — a `FakeLanguageModel` for `TestModel::Fake`,
    // otherwise a real provider model.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store used to register fake MCP servers (see `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    // In-memory filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
2290
/// Which language model a test should run against.
enum TestModel {
    /// A real Claude Sonnet 4 model resolved from the registry; `setup`
    /// authenticates its provider, so this hits the network.
    Sonnet4,
    /// An in-memory `FakeLanguageModel` with scripted responses.
    Fake,
}
2295
2296impl TestModel {
2297    fn id(&self) -> LanguageModelId {
2298        match self {
2299            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2300            TestModel::Fake => unreachable!(),
2301        }
2302    }
2303}
2304
/// Builds a [`ThreadTest`] fixture: fake filesystem with a settings file, a
/// test project, and a `Thread` wired to the requested model.
///
/// `TestModel::Fake` yields an in-memory `FakeLanguageModel`; other variants
/// resolve a real registry model and authenticate its provider (network
/// access required).
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Parking is allowed because setup may perform real I/O (e.g. provider
    // authentication through the production client).
    cx.executor().allow_parking();

    // Write a settings file selecting a profile that enables all test tools.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits to the fake
        // settings file for the duration of the test.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are constructed directly, real ones are
    // looked up in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2404
// Runs once before any test (via `ctor`), wiring up env_logger only when
// the developer requested log output through the RUST_LOG environment
// variable.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG");
    if logging_requested.is_ok() {
        env_logger::init();
    }
}
2412
2413fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2414    let fs = fs.clone();
2415    cx.spawn({
2416        async move |cx| {
2417            let mut new_settings_content_rx = settings::watch_config_file(
2418                cx.background_executor(),
2419                fs,
2420                paths::settings_file().clone(),
2421            );
2422
2423            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2424                cx.update(|cx| {
2425                    SettingsStore::update_global(cx, |settings, cx| {
2426                        settings.set_user_settings(&new_settings_content, cx)
2427                    })
2428                })
2429                .ok();
2430            }
2431        }
2432    })
2433    .detach();
2434}
2435
2436fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2437    completion
2438        .tools
2439        .iter()
2440        .map(|tool| tool.name.clone())
2441        .collect()
2442}
2443
/// Registers a fake MCP context server called `name` that advertises
/// `tools`, and starts it in `context_server_store`.
///
/// Returns a channel that yields one `(params, response_sender)` pair per
/// tool invocation; the test must send a `CallToolResponse` through the
/// sender to let the call complete.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport that answers the MCP handshake and forwards tool calls
    // to the test through the returned channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake finish before returning the call receiver.
    cx.run_until_parked();
    mcp_tool_calls_rx
}