mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use fs::{FakeFs, Fs};
  11use futures::{
  12    StreamExt,
  13    channel::{
  14        mpsc::{self, UnboundedReceiver},
  15        oneshot,
  16    },
  17};
  18use gpui::{
  19    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  20};
  21use indoc::indoc;
  22use language_model::{
  23    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  24    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  25    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  26    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  27};
  28use pretty_assertions::assert_eq;
  29use project::{
  30    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  31};
  32use prompt_store::ProjectContext;
  33use reqwest_client::ReqwestClient;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use serde_json::json;
  37use settings::{Settings, SettingsStore};
  38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  39use util::path;
  40
  41mod test_tools;
  42use test_tools::*;
  43
  44#[gpui::test]
  45async fn test_echo(cx: &mut TestAppContext) {
  46    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  47    let fake_model = model.as_fake();
  48
  49    let events = thread
  50        .update(cx, |thread, cx| {
  51            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  52        })
  53        .unwrap();
  54    cx.run_until_parked();
  55    fake_model.send_last_completion_stream_text_chunk("Hello");
  56    fake_model
  57        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  58    fake_model.end_last_completion_stream();
  59
  60    let events = events.collect().await;
  61    thread.update(cx, |thread, _cx| {
  62        assert_eq!(
  63            thread.last_message().unwrap().to_markdown(),
  64            indoc! {"
  65                ## Assistant
  66
  67                Hello
  68            "}
  69        )
  70    });
  71    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  72}
  73
  74#[gpui::test]
  75async fn test_thinking(cx: &mut TestAppContext) {
  76    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  77    let fake_model = model.as_fake();
  78
  79    let events = thread
  80        .update(cx, |thread, cx| {
  81            thread.send(
  82                UserMessageId::new(),
  83                [indoc! {"
  84                    Testing:
  85
  86                    Generate a thinking step where you just think the word 'Think',
  87                    and have your final answer be 'Hello'
  88                "}],
  89                cx,
  90            )
  91        })
  92        .unwrap();
  93    cx.run_until_parked();
  94    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
  95        text: "Think".to_string(),
  96        signature: None,
  97    });
  98    fake_model.send_last_completion_stream_text_chunk("Hello");
  99    fake_model
 100        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 101    fake_model.end_last_completion_stream();
 102
 103    let events = events.collect().await;
 104    thread.update(cx, |thread, _cx| {
 105        assert_eq!(
 106            thread.last_message().unwrap().to_markdown(),
 107            indoc! {"
 108                ## Assistant
 109
 110                <think>Think</think>
 111                Hello
 112            "}
 113        )
 114    });
 115    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 116}
 117
 118#[gpui::test]
 119async fn test_system_prompt(cx: &mut TestAppContext) {
 120    let ThreadTest {
 121        model,
 122        thread,
 123        project_context,
 124        ..
 125    } = setup(cx, TestModel::Fake).await;
 126    let fake_model = model.as_fake();
 127
 128    project_context.update(cx, |project_context, _cx| {
 129        project_context.shell = "test-shell".into()
 130    });
 131    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 132    thread
 133        .update(cx, |thread, cx| {
 134            thread.send(UserMessageId::new(), ["abc"], cx)
 135        })
 136        .unwrap();
 137    cx.run_until_parked();
 138    let mut pending_completions = fake_model.pending_completions();
 139    assert_eq!(
 140        pending_completions.len(),
 141        1,
 142        "unexpected pending completions: {:?}",
 143        pending_completions
 144    );
 145
 146    let pending_completion = pending_completions.pop().unwrap();
 147    assert_eq!(pending_completion.messages[0].role, Role::System);
 148
 149    let system_message = &pending_completion.messages[0];
 150    let system_prompt = system_message.content[0].to_str().unwrap();
 151    assert!(
 152        system_prompt.contains("test-shell"),
 153        "unexpected system message: {:?}",
 154        system_message
 155    );
 156    assert!(
 157        system_prompt.contains("## Fixing Diagnostics"),
 158        "unexpected system message: {:?}",
 159        system_message
 160    );
 161}
 162
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: in every request sent to the model,
    // only the most recent message (user message or tool result) carries
    // `cache: true`; all earlier conversation messages are marked
    // `cache: false`.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so only the conversation tail is
    // compared here and below.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request must end with the tool
    // result as the only cached message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 296
 297#[gpui::test]
 298#[cfg_attr(not(feature = "e2e"), ignore)]
 299async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 300    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 301
 302    // Test a tool call that's likely to complete *before* streaming stops.
 303    let events = thread
 304        .update(cx, |thread, cx| {
 305            thread.add_tool(EchoTool);
 306            thread.send(
 307                UserMessageId::new(),
 308                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 309                cx,
 310            )
 311        })
 312        .unwrap()
 313        .collect()
 314        .await;
 315    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 316
 317    // Test a tool calls that's likely to complete *after* streaming stops.
 318    let events = thread
 319        .update(cx, |thread, cx| {
 320            thread.remove_tool(&EchoTool::name());
 321            thread.add_tool(DelayTool);
 322            thread.send(
 323                UserMessageId::new(),
 324                [
 325                    "Now call the delay tool with 200ms.",
 326                    "When the timer goes off, then you echo the output of the tool.",
 327                ],
 328                cx,
 329            )
 330        })
 331        .unwrap()
 332        .collect()
 333        .await;
 334    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 335    thread.update(cx, |thread, _cx| {
 336        assert!(
 337            thread
 338                .last_message()
 339                .unwrap()
 340                .as_agent_message()
 341                .unwrap()
 342                .content
 343                .iter()
 344                .any(|content| {
 345                    if let AgentMessageContent::Text(text) = content {
 346                        text.contains("Ding")
 347                    } else {
 348                        false
 349                    }
 350                }),
 351            "{}",
 352            thread.to_markdown()
 353        );
 354    });
 355}
 356
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end check that tool-call input streams incrementally: while the
    // call is still Pending the last tool use in the thread should have
    // incomplete input, and once it leaves Pending the full input (all keys,
    // here 'a' through 'g') must be present.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // A Pending call whose input hasn't streamed the last
                        // key ("g") yet is the partial state we want to observe.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once no longer Pending, the input must be complete.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 408
 409#[gpui::test]
 410async fn test_tool_authorization(cx: &mut TestAppContext) {
 411    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 412    let fake_model = model.as_fake();
 413
 414    let mut events = thread
 415        .update(cx, |thread, cx| {
 416            thread.add_tool(ToolRequiringPermission);
 417            thread.send(UserMessageId::new(), ["abc"], cx)
 418        })
 419        .unwrap();
 420    cx.run_until_parked();
 421    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 422        LanguageModelToolUse {
 423            id: "tool_id_1".into(),
 424            name: ToolRequiringPermission::name().into(),
 425            raw_input: "{}".into(),
 426            input: json!({}),
 427            is_input_complete: true,
 428        },
 429    ));
 430    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 431        LanguageModelToolUse {
 432            id: "tool_id_2".into(),
 433            name: ToolRequiringPermission::name().into(),
 434            raw_input: "{}".into(),
 435            input: json!({}),
 436            is_input_complete: true,
 437        },
 438    ));
 439    fake_model.end_last_completion_stream();
 440    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 441    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 442
 443    // Approve the first
 444    tool_call_auth_1
 445        .response
 446        .send(tool_call_auth_1.options[1].id.clone())
 447        .unwrap();
 448    cx.run_until_parked();
 449
 450    // Reject the second
 451    tool_call_auth_2
 452        .response
 453        .send(tool_call_auth_1.options[2].id.clone())
 454        .unwrap();
 455    cx.run_until_parked();
 456
 457    let completion = fake_model.pending_completions().pop().unwrap();
 458    let message = completion.messages.last().unwrap();
 459    assert_eq!(
 460        message.content,
 461        vec![
 462            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 463                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 464                tool_name: ToolRequiringPermission::name().into(),
 465                is_error: false,
 466                content: "Allowed".into(),
 467                output: Some("Allowed".into())
 468            }),
 469            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 470                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 471                tool_name: ToolRequiringPermission::name().into(),
 472                is_error: true,
 473                content: "Permission to run tool denied by user".into(),
 474                output: Some("Permission to run tool denied by user".into())
 475            })
 476        ]
 477    );
 478
 479    // Simulate yet another tool call.
 480    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 481        LanguageModelToolUse {
 482            id: "tool_id_3".into(),
 483            name: ToolRequiringPermission::name().into(),
 484            raw_input: "{}".into(),
 485            input: json!({}),
 486            is_input_complete: true,
 487        },
 488    ));
 489    fake_model.end_last_completion_stream();
 490
 491    // Respond by always allowing tools.
 492    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 493    tool_call_auth_3
 494        .response
 495        .send(tool_call_auth_3.options[0].id.clone())
 496        .unwrap();
 497    cx.run_until_parked();
 498    let completion = fake_model.pending_completions().pop().unwrap();
 499    let message = completion.messages.last().unwrap();
 500    assert_eq!(
 501        message.content,
 502        vec![language_model::MessageContent::ToolResult(
 503            LanguageModelToolResult {
 504                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 505                tool_name: ToolRequiringPermission::name().into(),
 506                is_error: false,
 507                content: "Allowed".into(),
 508                output: Some("Allowed".into())
 509            }
 510        )]
 511    );
 512
 513    // Simulate a final tool call, ensuring we don't trigger authorization.
 514    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 515        LanguageModelToolUse {
 516            id: "tool_id_4".into(),
 517            name: ToolRequiringPermission::name().into(),
 518            raw_input: "{}".into(),
 519            input: json!({}),
 520            is_input_complete: true,
 521        },
 522    ));
 523    fake_model.end_last_completion_stream();
 524    cx.run_until_parked();
 525    let completion = fake_model.pending_completions().pop().unwrap();
 526    let message = completion.messages.last().unwrap();
 527    assert_eq!(
 528        message.content,
 529        vec![language_model::MessageContent::ToolResult(
 530            LanguageModelToolResult {
 531                tool_use_id: "tool_id_4".into(),
 532                tool_name: ToolRequiringPermission::name().into(),
 533                is_error: false,
 534                content: "Allowed".into(),
 535                output: Some("Allowed".into())
 536            }
 537        )]
 538    );
 539}
 540
 541#[gpui::test]
 542async fn test_tool_hallucination(cx: &mut TestAppContext) {
 543    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 544    let fake_model = model.as_fake();
 545
 546    let mut events = thread
 547        .update(cx, |thread, cx| {
 548            thread.send(UserMessageId::new(), ["abc"], cx)
 549        })
 550        .unwrap();
 551    cx.run_until_parked();
 552    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 553        LanguageModelToolUse {
 554            id: "tool_id_1".into(),
 555            name: "nonexistent_tool".into(),
 556            raw_input: "{}".into(),
 557            input: json!({}),
 558            is_input_complete: true,
 559        },
 560    ));
 561    fake_model.end_last_completion_stream();
 562
 563    let tool_call = expect_tool_call(&mut events).await;
 564    assert_eq!(tool_call.title, "nonexistent_tool");
 565    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 566    let update = expect_tool_call_update_fields(&mut events).await;
 567    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 568}
 569
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, the turn ends with a
    // ToolUseLimitReachedError; calling `resume` must re-send the whole
    // conversation with an extra "Continue where you left off" user message,
    // which becomes the only cached message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The follow-up request carries the tool use and its result; the tool
    // result, being the most recent message, is the cached one.
    // (messages[0] is the system prompt, hence the [1..] slice.)
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The final event of the turn must be the limit error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message instead of a new prompt.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
 678
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, sending a brand-new
    // user message (rather than resuming) must still include the prior tool
    // use and its result in the next request, with only the new message
    // cached.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use immediately followed by the limit status, then close
    // the stream: the turn must end with a ToolUseLimitReachedError.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt, hence the [1..] slice.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 752
 753async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 754    let event = events
 755        .next()
 756        .await
 757        .expect("no tool call authorization event received")
 758        .unwrap();
 759    match event {
 760        ThreadEvent::ToolCall(tool_call) => tool_call,
 761        event => {
 762            panic!("Unexpected event {event:?}");
 763        }
 764    }
 765}
 766
 767async fn expect_tool_call_update_fields(
 768    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 769) -> acp::ToolCallUpdate {
 770    let event = events
 771        .next()
 772        .await
 773        .expect("no tool call authorization event received")
 774        .unwrap();
 775    match event {
 776        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 777        event => {
 778            panic!("Unexpected event {event:?}");
 779        }
 780    }
 781}
 782
 783async fn next_tool_call_authorization(
 784    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 785) -> ToolCallAuthorization {
 786    loop {
 787        let event = events
 788            .next()
 789            .await
 790            .expect("no tool call authorization event received")
 791            .unwrap();
 792        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 793            let permission_kinds = tool_call_authorization
 794                .options
 795                .iter()
 796                .map(|o| o.kind)
 797                .collect::<Vec<_>>();
 798            assert_eq!(
 799                permission_kinds,
 800                vec![
 801                    acp::PermissionOptionKind::AllowAlways,
 802                    acp::PermissionOptionKind::AllowOnce,
 803                    acp::PermissionOptionKind::RejectOnce,
 804                ]
 805            );
 806            return tool_call_authorization;
 807        }
 808    }
 809}
 810
 811#[gpui::test]
 812#[cfg_attr(not(feature = "e2e"), ignore)]
 813async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 814    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 815
 816    // Test concurrent tool calls with different delay times
 817    let events = thread
 818        .update(cx, |thread, cx| {
 819            thread.add_tool(DelayTool);
 820            thread.send(
 821                UserMessageId::new(),
 822                [
 823                    "Call the delay tool twice in the same message.",
 824                    "Once with 100ms. Once with 300ms.",
 825                    "When both timers are complete, describe the outputs.",
 826                ],
 827                cx,
 828            )
 829        })
 830        .unwrap()
 831        .collect()
 832        .await;
 833
 834    let stop_reasons = stop_events(events);
 835    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 836
 837    thread.update(cx, |thread, _cx| {
 838        let last_message = thread.last_message().unwrap();
 839        let agent_message = last_message.as_agent_message().unwrap();
 840        let text = agent_message
 841            .content
 842            .iter()
 843            .filter_map(|content| {
 844                if let AgentMessageContent::Text(text) = content {
 845                    Some(text.as_str())
 846                } else {
 847                    None
 848                }
 849            })
 850            .collect::<String>();
 851
 852        assert!(text.contains("Ding"));
 853    });
 854}
 855
 856#[gpui::test]
 857async fn test_profiles(cx: &mut TestAppContext) {
 858    let ThreadTest {
 859        model, thread, fs, ..
 860    } = setup(cx, TestModel::Fake).await;
 861    let fake_model = model.as_fake();
 862
 863    thread.update(cx, |thread, _cx| {
 864        thread.add_tool(DelayTool);
 865        thread.add_tool(EchoTool);
 866        thread.add_tool(InfiniteTool);
 867    });
 868
 869    // Override profiles and wait for settings to be loaded.
 870    fs.insert_file(
 871        paths::settings_file(),
 872        json!({
 873            "agent": {
 874                "profiles": {
 875                    "test-1": {
 876                        "name": "Test Profile 1",
 877                        "tools": {
 878                            EchoTool::name(): true,
 879                            DelayTool::name(): true,
 880                        }
 881                    },
 882                    "test-2": {
 883                        "name": "Test Profile 2",
 884                        "tools": {
 885                            InfiniteTool::name(): true,
 886                        }
 887                    }
 888                }
 889            }
 890        })
 891        .to_string()
 892        .into_bytes(),
 893    )
 894    .await;
 895    cx.run_until_parked();
 896
 897    // Test that test-1 profile (default) has echo and delay tools
 898    thread
 899        .update(cx, |thread, cx| {
 900            thread.set_profile(AgentProfileId("test-1".into()));
 901            thread.send(UserMessageId::new(), ["test"], cx)
 902        })
 903        .unwrap();
 904    cx.run_until_parked();
 905
 906    let mut pending_completions = fake_model.pending_completions();
 907    assert_eq!(pending_completions.len(), 1);
 908    let completion = pending_completions.pop().unwrap();
 909    let tool_names: Vec<String> = completion
 910        .tools
 911        .iter()
 912        .map(|tool| tool.name.clone())
 913        .collect();
 914    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
 915    fake_model.end_last_completion_stream();
 916
 917    // Switch to test-2 profile, and verify that it has only the infinite tool.
 918    thread
 919        .update(cx, |thread, cx| {
 920            thread.set_profile(AgentProfileId("test-2".into()));
 921            thread.send(UserMessageId::new(), ["test2"], cx)
 922        })
 923        .unwrap();
 924    cx.run_until_parked();
 925    let mut pending_completions = fake_model.pending_completions();
 926    assert_eq!(pending_completions.len(), 1);
 927    let completion = pending_completions.pop().unwrap();
 928    let tool_names: Vec<String> = completion
 929        .tools
 930        .iter()
 931        .map(|tool| tool.name.clone())
 932        .collect();
 933    assert_eq!(tool_names, vec![InfiniteTool::name()]);
 934}
 935
 936#[gpui::test]
 937async fn test_mcp_tools(cx: &mut TestAppContext) {
 938    let ThreadTest {
 939        model,
 940        thread,
 941        context_server_store,
 942        fs,
 943        ..
 944    } = setup(cx, TestModel::Fake).await;
 945    let fake_model = model.as_fake();
 946
 947    // Override profiles and wait for settings to be loaded.
 948    fs.insert_file(
 949        paths::settings_file(),
 950        json!({
 951            "agent": {
 952                "always_allow_tool_actions": true,
 953                "profiles": {
 954                    "test": {
 955                        "name": "Test Profile",
 956                        "enable_all_context_servers": true,
 957                        "tools": {
 958                            EchoTool::name(): true,
 959                        }
 960                    },
 961                }
 962            }
 963        })
 964        .to_string()
 965        .into_bytes(),
 966    )
 967    .await;
 968    cx.run_until_parked();
 969    thread.update(cx, |thread, _| {
 970        thread.set_profile(AgentProfileId("test".into()))
 971    });
 972
 973    let mut mcp_tool_calls = setup_context_server(
 974        "test_server",
 975        vec![context_server::types::Tool {
 976            name: "echo".into(),
 977            description: None,
 978            input_schema: serde_json::to_value(
 979                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
 980            )
 981            .unwrap(),
 982            output_schema: None,
 983            annotations: None,
 984        }],
 985        &context_server_store,
 986        cx,
 987    );
 988
 989    let events = thread.update(cx, |thread, cx| {
 990        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
 991    });
 992    cx.run_until_parked();
 993
 994    // Simulate the model calling the MCP tool.
 995    let completion = fake_model.pending_completions().pop().unwrap();
 996    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
 997    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 998        LanguageModelToolUse {
 999            id: "tool_1".into(),
1000            name: "echo".into(),
1001            raw_input: json!({"text": "test"}).to_string(),
1002            input: json!({"text": "test"}),
1003            is_input_complete: true,
1004        },
1005    ));
1006    fake_model.end_last_completion_stream();
1007    cx.run_until_parked();
1008
1009    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1010    assert_eq!(tool_call_params.name, "echo");
1011    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1012    tool_call_response
1013        .send(context_server::types::CallToolResponse {
1014            content: vec![context_server::types::ToolResponseContent::Text {
1015                text: "test".into(),
1016            }],
1017            is_error: None,
1018            meta: None,
1019            structured_content: None,
1020        })
1021        .unwrap();
1022    cx.run_until_parked();
1023
1024    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1025    fake_model.send_last_completion_stream_text_chunk("Done!");
1026    fake_model.end_last_completion_stream();
1027    events.collect::<Vec<_>>().await;
1028
1029    // Send again after adding the echo tool, ensuring the name collision is resolved.
1030    let events = thread.update(cx, |thread, cx| {
1031        thread.add_tool(EchoTool);
1032        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1033    });
1034    cx.run_until_parked();
1035    let completion = fake_model.pending_completions().pop().unwrap();
1036    assert_eq!(
1037        tool_names_for_completion(&completion),
1038        vec!["echo", "test_server_echo"]
1039    );
1040    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1041        LanguageModelToolUse {
1042            id: "tool_2".into(),
1043            name: "test_server_echo".into(),
1044            raw_input: json!({"text": "mcp"}).to_string(),
1045            input: json!({"text": "mcp"}),
1046            is_input_complete: true,
1047        },
1048    ));
1049    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1050        LanguageModelToolUse {
1051            id: "tool_3".into(),
1052            name: "echo".into(),
1053            raw_input: json!({"text": "native"}).to_string(),
1054            input: json!({"text": "native"}),
1055            is_input_complete: true,
1056        },
1057    ));
1058    fake_model.end_last_completion_stream();
1059    cx.run_until_parked();
1060
1061    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1062    assert_eq!(tool_call_params.name, "echo");
1063    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1064    tool_call_response
1065        .send(context_server::types::CallToolResponse {
1066            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1067            is_error: None,
1068            meta: None,
1069            structured_content: None,
1070        })
1071        .unwrap();
1072    cx.run_until_parked();
1073
1074    // Ensure the tool results were inserted with the correct names.
1075    let completion = fake_model.pending_completions().pop().unwrap();
1076    assert_eq!(
1077        completion.messages.last().unwrap().content,
1078        vec![
1079            MessageContent::ToolResult(LanguageModelToolResult {
1080                tool_use_id: "tool_3".into(),
1081                tool_name: "echo".into(),
1082                is_error: false,
1083                content: "native".into(),
1084                output: Some("native".into()),
1085            },),
1086            MessageContent::ToolResult(LanguageModelToolResult {
1087                tool_use_id: "tool_2".into(),
1088                tool_name: "test_server_echo".into(),
1089                is_error: false,
1090                content: "mcp".into(),
1091                output: Some("mcp".into()),
1092            },),
1093        ]
1094    );
1095    fake_model.end_last_completion_stream();
1096    events.collect::<Vec<_>>().await;
1097}
1098
// Verifies how tool names are disambiguated and truncated when multiple MCP
// servers expose conflicting and/or over-long tool names alongside native
// tools. The final assertion pins the exact renamed set offered to the model.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Register the native tools that the MCP tool names will collide with.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy" also has names near the MAX_TOOL_NAME_LENGTH boundary that
    // collide with server "zzz" below.
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                // Exceeds the limit outright, so it must be truncated even
                // without a collision.
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected scheme (as pinned below): MCP names that collide with native
    // tools get a server-name prefix ("xxx_echo"/"yyy_echo"); colliding
    // near-limit names get a shortened server prefix plus truncation
    // ("y_aaa…"/"z_aaa…"); over-long names are truncated. NOTE(review): only
    // one of the two conflicting "b…" tools appears in the expected output —
    // confirm whether that deduplication is intended.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1264
// End-to-end test: cancelling a turn while a tool is still running must close
// the event stream with a Cancelled stop reason, and the thread must remain
// usable for a subsequent send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the order the prompt requested.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track when the echo tool (specifically) reports completion.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have been called and echo has completed;
        // the infinite tool is intentionally left running.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1349
1350#[gpui::test]
1351async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1352    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1353    let fake_model = model.as_fake();
1354
1355    let events_1 = thread
1356        .update(cx, |thread, cx| {
1357            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1358        })
1359        .unwrap();
1360    cx.run_until_parked();
1361    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1362    cx.run_until_parked();
1363
1364    let events_2 = thread
1365        .update(cx, |thread, cx| {
1366            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1367        })
1368        .unwrap();
1369    cx.run_until_parked();
1370    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1371    fake_model
1372        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1373    fake_model.end_last_completion_stream();
1374
1375    let events_1 = events_1.collect::<Vec<_>>().await;
1376    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1377    let events_2 = events_2.collect::<Vec<_>>().await;
1378    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1379}
1380
1381#[gpui::test]
1382async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1383    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1384    let fake_model = model.as_fake();
1385
1386    let events_1 = thread
1387        .update(cx, |thread, cx| {
1388            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1389        })
1390        .unwrap();
1391    cx.run_until_parked();
1392    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1393    fake_model
1394        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1395    fake_model.end_last_completion_stream();
1396    let events_1 = events_1.collect::<Vec<_>>().await;
1397
1398    let events_2 = thread
1399        .update(cx, |thread, cx| {
1400            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1401        })
1402        .unwrap();
1403    cx.run_until_parked();
1404    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1405    fake_model
1406        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1407    fake_model.end_last_completion_stream();
1408    let events_2 = events_2.collect::<Vec<_>>().await;
1409
1410    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1411    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1412}
1413
// When the model refuses (StopReason::Refusal), the thread should roll back
// everything after the last user message — here, the entire conversation.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model responds, only the user message is in the transcript.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant response so there is content to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The refused user message is removed too, leaving an empty transcript.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1462
// Truncating at the first (and only) user message should empty the transcript
// and reset token usage, while leaving the thread usable for new sends.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update so there is state to discard.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the usage update above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage after truncation reflects only the new turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1578
// Truncating at the second user message should restore the thread to exactly
// its state after the first exchange, including the earlier token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message, response, and a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Captures the expected post-first-exchange state; asserted both before
    // the second exchange and again after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, keeping the message id for truncation.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both exchanges are present and usage reflects the latest turn.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the second message rolls back to the first exchange,
    // including the earlier usage numbers.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1687
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // Verifies that a thread title is produced by the dedicated summarization
    // model after the first exchange, that only the first line of the summary
    // is used, and that subsequent messages don't regenerate the title.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Use a separate fake model for summarization so title-generation
    // completions can be distinguished from regular chat completions.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No summary has streamed yet, so the default title is still in place.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The summary streams across chunk boundaries; only the text before the
    // first newline ("Hello world") should become the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summarization model must not have been asked for a new completion,
    // and the previously generated title must be preserved.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1733
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // Verifies that when a completion request is rebuilt mid-turn, tool uses
    // that are still pending (here: one awaiting permission) are omitted,
    // while completed tool uses and their results are included.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will block on user authorization and stay pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool call runs immediately and completes with a result.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system prompt; only the conversation is compared.
    // Note the permission tool use is absent from the assistant message.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1808
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of the ACP connection surface: model listing and
    // selection, thread creation, prompting, cancellation, and session
    // teardown when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the session and check the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with a "Session not found" error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1941
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the sequence of ACP tool-call events emitted while a tool's
    // input streams in and it then runs to completion: initial ToolCall,
    // raw_input update, InProgress, content update, and Completed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First event: the tool call is announced in Pending state with the
    // partial (empty) input that had streamed so far.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // Once input is complete, the call is updated with the full raw_input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // The tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // Finally, the tool finishes and reports its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2043
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // A completion that streams cleanly to its end must not trigger any
    // Retry events, and the transcript must contain the full exchange.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the completion mode under which retries apply.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream until the turn stops, collecting Retry events.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2087
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A ServerOverloaded error mid-stream must produce exactly one Retry
    // event and resume the turn after the provider-supplied retry delay,
    // stitching the resumed output into the transcript.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the completion mode under which retries apply.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial chunk, then fail with a retryable overload error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript records both the interrupted and the resumed output.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2152
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // When a retryable error interrupts the stream after a tool use, the
    // tool call must still be run to completion so that the retried request
    // contains both the tool use and its result.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the completion mode under which retries apply.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    // Fail the stream right after the tool use arrives.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the new request must already include the
    // completed tool call and its result (messages[0] is the system prompt).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Let the retried completion finish and verify the final message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2228
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, the next
    // failure must surface as an error on the event stream instead of
    // another retry.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the completion mode under which retries apply.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail one more time than the retry budget allows, advancing the clock
    // past each retry_after delay so every scheduled retry actually runs.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // Exactly MAX_RETRY_ATTEMPTS retries, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a single ServerOverloaded error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2283
2284/// Filters out the stop events for asserting against in tests
2285fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2286    result_events
2287        .into_iter()
2288        .filter_map(|event| match event.unwrap() {
2289            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2290            _ => None,
2291        })
2292        .collect()
2293}
2294
/// Everything [`setup`] builds for driving a [`Thread`] in a test.
struct ThreadTest {
    /// The model the thread completes against (fake or real, per [`TestModel`]).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    /// Store used by tests that register fake context (MCP) servers.
    context_server_store: Entity<ContextServerStore>,
    /// In-memory filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
2302
/// Which language model a test runs against.
enum TestModel {
    /// A real Anthropic model resolved from the registry (requires auth).
    Sonnet4,
    /// An in-process [`FakeLanguageModel`] with scripted responses.
    Fake,
}
2307
2308impl TestModel {
2309    fn id(&self) -> LanguageModelId {
2310        match self {
2311            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2312            TestModel::Fake => unreachable!(),
2313        }
2314    }
2315}
2316
/// Builds the shared test fixture: a fake filesystem with a settings file
/// enabling all test tools, an initialized settings/project environment,
/// the requested model, and a [`Thread`] wired to all of the above.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file that enables every test tool under a dedicated
    // profile, so threads created below can call them.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            // The fake model needs no provider/auth setup.
            TestModel::Fake => {}
            // Real models need the full client + provider stack.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep global settings in sync with edits to the fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are constructed directly; real ones are
    // looked up in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2421
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Enable logging only when RUST_LOG is set, matching env_logger's
    // own opt-in convention.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2429
2430fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2431    let fs = fs.clone();
2432    cx.spawn({
2433        async move |cx| {
2434            let mut new_settings_content_rx = settings::watch_config_file(
2435                cx.background_executor(),
2436                fs,
2437                paths::settings_file().clone(),
2438            );
2439
2440            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2441                cx.update(|cx| {
2442                    SettingsStore::update_global(cx, |settings, cx| {
2443                        settings.set_user_settings(&new_settings_content, cx)
2444                    })
2445                })
2446                .ok();
2447            }
2448        }
2449    })
2450    .detach();
2451}
2452
2453fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2454    completion
2455        .tools
2456        .iter()
2457        .map(|tool| tool.name.clone())
2458        .collect()
2459}
2460
/// Registers a fake MCP context server named `name` exposing `tools`, and
/// starts it in `context_server_store`.
///
/// Returns a receiver that yields each incoming `CallTool` request together
/// with a oneshot sender the test must use to supply the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Register the server in project settings so the store will start it.
    // The command is a placeholder; the fake transport below handles I/O.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Script the MCP handshake: Initialize advertises tool support,
    // ListTools returns the provided tools, and CallTool forwards each
    // request to the test through the channel, awaiting the test's response.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Wait for the server to finish starting before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}