// mod.rs — agent thread tests (stray filename header converted to a comment)

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use fs::{FakeFs, Fs};
  11use futures::{
  12    StreamExt,
  13    channel::{
  14        mpsc::{self, UnboundedReceiver},
  15        oneshot,
  16    },
  17};
  18use gpui::{
  19    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  20};
  21use indoc::indoc;
  22use language_model::{
  23    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  24    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  25    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  26    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  27};
  28use pretty_assertions::assert_eq;
  29use project::{
  30    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  31};
  32use prompt_store::ProjectContext;
  33use reqwest_client::ReqwestClient;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use serde_json::json;
  37use settings::{Settings, SettingsStore};
  38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  39use util::path;
  40
  41mod test_tools;
  42use test_tools::*;
  43
  44#[gpui::test]
  45async fn test_echo(cx: &mut TestAppContext) {
  46    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  47    let fake_model = model.as_fake();
  48
  49    let events = thread
  50        .update(cx, |thread, cx| {
  51            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  52        })
  53        .unwrap();
  54    cx.run_until_parked();
  55    fake_model.send_last_completion_stream_text_chunk("Hello");
  56    fake_model
  57        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  58    fake_model.end_last_completion_stream();
  59
  60    let events = events.collect().await;
  61    thread.update(cx, |thread, _cx| {
  62        assert_eq!(
  63            thread.last_message().unwrap().to_markdown(),
  64            indoc! {"
  65                ## Assistant
  66
  67                Hello
  68            "}
  69        )
  70    });
  71    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  72}
  73
  74#[gpui::test]
  75async fn test_thinking(cx: &mut TestAppContext) {
  76    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  77    let fake_model = model.as_fake();
  78
  79    let events = thread
  80        .update(cx, |thread, cx| {
  81            thread.send(
  82                UserMessageId::new(),
  83                [indoc! {"
  84                    Testing:
  85
  86                    Generate a thinking step where you just think the word 'Think',
  87                    and have your final answer be 'Hello'
  88                "}],
  89                cx,
  90            )
  91        })
  92        .unwrap();
  93    cx.run_until_parked();
  94    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
  95        text: "Think".to_string(),
  96        signature: None,
  97    });
  98    fake_model.send_last_completion_stream_text_chunk("Hello");
  99    fake_model
 100        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 101    fake_model.end_last_completion_stream();
 102
 103    let events = events.collect().await;
 104    thread.update(cx, |thread, _cx| {
 105        assert_eq!(
 106            thread.last_message().unwrap().to_markdown(),
 107            indoc! {"
 108                ## Assistant
 109
 110                <think>Think</think>
 111                Hello
 112            "}
 113        )
 114    });
 115    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 116}
 117
 118#[gpui::test]
 119async fn test_system_prompt(cx: &mut TestAppContext) {
 120    let ThreadTest {
 121        model,
 122        thread,
 123        project_context,
 124        ..
 125    } = setup(cx, TestModel::Fake).await;
 126    let fake_model = model.as_fake();
 127
 128    project_context.update(cx, |project_context, _cx| {
 129        project_context.shell = "test-shell".into()
 130    });
 131    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 132    thread
 133        .update(cx, |thread, cx| {
 134            thread.send(UserMessageId::new(), ["abc"], cx)
 135        })
 136        .unwrap();
 137    cx.run_until_parked();
 138    let mut pending_completions = fake_model.pending_completions();
 139    assert_eq!(
 140        pending_completions.len(),
 141        1,
 142        "unexpected pending completions: {:?}",
 143        pending_completions
 144    );
 145
 146    let pending_completion = pending_completions.pop().unwrap();
 147    assert_eq!(pending_completion.messages[0].role, Role::System);
 148
 149    let system_message = &pending_completion.messages[0];
 150    let system_prompt = system_message.content[0].to_str().unwrap();
 151    assert!(
 152        system_prompt.contains("test-shell"),
 153        "unexpected system message: {:?}",
 154        system_message
 155    );
 156    assert!(
 157        system_prompt.contains("## Fixing Diagnostics"),
 158        "unexpected system message: {:?}",
 159        system_message
 160    );
 161}
 162
/// Verifies prompt-cache marker placement: after every turn, only the most
/// recent user message (or tool result) is sent with `cache: true`, and all
/// earlier conversation messages are sent with `cache: false`.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so assertions compare from index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request ends in the tool result, which should now be the
    // only message carrying the cache marker.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 296
/// End-to-end test (real Sonnet 4 model; only runs with the `e2e` feature)
/// that tool calls complete whether they finish before or after the model
/// stops streaming, and that tool output reaches the final message.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The delay tool's output ("Ding") should be echoed somewhere in the
    // assistant's final message text.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 356
/// End-to-end test (real Sonnet 4 model; only runs with the `e2e` feature)
/// that tool-call input streams incrementally: while a call is pending we
/// should observe at least one partially-streamed tool use, and once the
/// call leaves the pending state its input must be complete.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Stream a word_list tool call; the assertions below imply its input has
    // fields "a" through "g" that arrive over multiple chunks.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, record whether we caught the input
                        // mid-stream (incomplete, with the last field "g"
                        // not yet present).
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once the call is no longer pending, the full input
                        // must have streamed in.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 408
 409#[gpui::test]
 410async fn test_tool_authorization(cx: &mut TestAppContext) {
 411    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 412    let fake_model = model.as_fake();
 413
 414    let mut events = thread
 415        .update(cx, |thread, cx| {
 416            thread.add_tool(ToolRequiringPermission);
 417            thread.send(UserMessageId::new(), ["abc"], cx)
 418        })
 419        .unwrap();
 420    cx.run_until_parked();
 421    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 422        LanguageModelToolUse {
 423            id: "tool_id_1".into(),
 424            name: ToolRequiringPermission::name().into(),
 425            raw_input: "{}".into(),
 426            input: json!({}),
 427            is_input_complete: true,
 428        },
 429    ));
 430    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 431        LanguageModelToolUse {
 432            id: "tool_id_2".into(),
 433            name: ToolRequiringPermission::name().into(),
 434            raw_input: "{}".into(),
 435            input: json!({}),
 436            is_input_complete: true,
 437        },
 438    ));
 439    fake_model.end_last_completion_stream();
 440    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 441    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 442
 443    // Approve the first
 444    tool_call_auth_1
 445        .response
 446        .send(tool_call_auth_1.options[1].id.clone())
 447        .unwrap();
 448    cx.run_until_parked();
 449
 450    // Reject the second
 451    tool_call_auth_2
 452        .response
 453        .send(tool_call_auth_1.options[2].id.clone())
 454        .unwrap();
 455    cx.run_until_parked();
 456
 457    let completion = fake_model.pending_completions().pop().unwrap();
 458    let message = completion.messages.last().unwrap();
 459    assert_eq!(
 460        message.content,
 461        vec![
 462            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 463                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 464                tool_name: ToolRequiringPermission::name().into(),
 465                is_error: false,
 466                content: "Allowed".into(),
 467                output: Some("Allowed".into())
 468            }),
 469            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 470                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 471                tool_name: ToolRequiringPermission::name().into(),
 472                is_error: true,
 473                content: "Permission to run tool denied by user".into(),
 474                output: Some("Permission to run tool denied by user".into())
 475            })
 476        ]
 477    );
 478
 479    // Simulate yet another tool call.
 480    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 481        LanguageModelToolUse {
 482            id: "tool_id_3".into(),
 483            name: ToolRequiringPermission::name().into(),
 484            raw_input: "{}".into(),
 485            input: json!({}),
 486            is_input_complete: true,
 487        },
 488    ));
 489    fake_model.end_last_completion_stream();
 490
 491    // Respond by always allowing tools.
 492    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 493    tool_call_auth_3
 494        .response
 495        .send(tool_call_auth_3.options[0].id.clone())
 496        .unwrap();
 497    cx.run_until_parked();
 498    let completion = fake_model.pending_completions().pop().unwrap();
 499    let message = completion.messages.last().unwrap();
 500    assert_eq!(
 501        message.content,
 502        vec![language_model::MessageContent::ToolResult(
 503            LanguageModelToolResult {
 504                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 505                tool_name: ToolRequiringPermission::name().into(),
 506                is_error: false,
 507                content: "Allowed".into(),
 508                output: Some("Allowed".into())
 509            }
 510        )]
 511    );
 512
 513    // Simulate a final tool call, ensuring we don't trigger authorization.
 514    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 515        LanguageModelToolUse {
 516            id: "tool_id_4".into(),
 517            name: ToolRequiringPermission::name().into(),
 518            raw_input: "{}".into(),
 519            input: json!({}),
 520            is_input_complete: true,
 521        },
 522    ));
 523    fake_model.end_last_completion_stream();
 524    cx.run_until_parked();
 525    let completion = fake_model.pending_completions().pop().unwrap();
 526    let message = completion.messages.last().unwrap();
 527    assert_eq!(
 528        message.content,
 529        vec![language_model::MessageContent::ToolResult(
 530            LanguageModelToolResult {
 531                tool_use_id: "tool_id_4".into(),
 532                tool_name: ToolRequiringPermission::name().into(),
 533                is_error: false,
 534                content: "Allowed".into(),
 535                output: Some("Allowed".into())
 536            }
 537        )]
 538    );
 539}
 540
 541#[gpui::test]
 542async fn test_tool_hallucination(cx: &mut TestAppContext) {
 543    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 544    let fake_model = model.as_fake();
 545
 546    let mut events = thread
 547        .update(cx, |thread, cx| {
 548            thread.send(UserMessageId::new(), ["abc"], cx)
 549        })
 550        .unwrap();
 551    cx.run_until_parked();
 552    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 553        LanguageModelToolUse {
 554            id: "tool_id_1".into(),
 555            name: "nonexistent_tool".into(),
 556            raw_input: "{}".into(),
 557            input: json!({}),
 558            is_input_complete: true,
 559        },
 560    ));
 561    fake_model.end_last_completion_stream();
 562
 563    let tool_call = expect_tool_call(&mut events).await;
 564    assert_eq!(tool_call.title, "nonexistent_tool");
 565    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 566    let update = expect_tool_call_update_fields(&mut events).await;
 567    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 568}
 569
/// After the provider reports `ToolUseLimitReached`, `resume` should re-send
/// the conversation with an extra "Continue where you left off" user message
/// (which becomes the newly cached message).
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Have the model invoke the echo tool once.
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The follow-up request carries the tool result; messages[0] is the
    // system prompt, so compare from index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should end with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a continuation prompt, which now carries the cache bit.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    // Finish the resumed turn and check it lands in the transcript.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
 678
/// After `ToolUseLimitReached`, sending a brand-new user message (instead of
/// resuming) should still include the prior tool use/result in the request,
/// with only the new message cached.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // One echo-tool call, immediately followed by the limit status.
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn's event stream must end with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh user message after the limit keeps the full history;
    // messages[0] is the system prompt, so compare from index 1.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 752
 753async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 754    let event = events
 755        .next()
 756        .await
 757        .expect("no tool call authorization event received")
 758        .unwrap();
 759    match event {
 760        ThreadEvent::ToolCall(tool_call) => tool_call,
 761        event => {
 762            panic!("Unexpected event {event:?}");
 763        }
 764    }
 765}
 766
 767async fn expect_tool_call_update_fields(
 768    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 769) -> acp::ToolCallUpdate {
 770    let event = events
 771        .next()
 772        .await
 773        .expect("no tool call authorization event received")
 774        .unwrap();
 775    match event {
 776        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 777        event => {
 778            panic!("Unexpected event {event:?}");
 779        }
 780    }
 781}
 782
 783async fn next_tool_call_authorization(
 784    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 785) -> ToolCallAuthorization {
 786    loop {
 787        let event = events
 788            .next()
 789            .await
 790            .expect("no tool call authorization event received")
 791            .unwrap();
 792        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 793            let permission_kinds = tool_call_authorization
 794                .options
 795                .iter()
 796                .map(|o| o.kind)
 797                .collect::<Vec<_>>();
 798            assert_eq!(
 799                permission_kinds,
 800                vec![
 801                    acp::PermissionOptionKind::AllowAlways,
 802                    acp::PermissionOptionKind::AllowOnce,
 803                    acp::PermissionOptionKind::RejectOnce,
 804                ]
 805            );
 806            return tool_call_authorization;
 807        }
 808    }
 809}
 810
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // E2E test (only runs with the `e2e` feature, against a real Sonnet 4
    // connection): asks the model to invoke the delay tool twice in one
    // message and verifies the turn still ends cleanly and the final agent
    // message describes the tool output.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    // The turn should complete normally despite the overlapping tool calls.
    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check it
        // mentions the delay tool's output ("Ding" — presumably what DelayTool
        // returns when its timer fires; see test_tools).
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
 855
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that agent profiles control which tools are exposed to the
    // model: profile "test-1" enables echo + delay, "test-2" enables only the
    // infinite tool, and switching profiles between sends changes the tool
    // list in the next completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; the active profile decides
    // which ones actually reach the model.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the completion request captured by the fake model; its tool
    // list should contain exactly the tools enabled by the active profile.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
 935
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Exercises the full MCP (context-server) tool round trip:
    // 1. An MCP server's "echo" tool is exposed to the model under its bare
    //    name when there is no conflict.
    // 2. After a native tool with the same name is added, the MCP tool is
    //    renamed to "<server>_<tool>" ("test_server_echo") to resolve the
    //    collision, and tool results are attributed to the correct names.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Fake MCP server exposing a single "echo" tool; `mcp_tool_calls` yields
    // (request, response-sender) pairs for each tool invocation.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    // No native "echo" tool yet, so the MCP tool keeps its bare name.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The thread should have forwarded the call to the MCP server; reply to it.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts the `completion` captured before the tool
    // call rather than popping the follow-up completion request — presumably
    // intentional (the old request is unchanged), but verify the follow-up
    // request was meant to be checked here instead.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Native tool keeps "echo"; the MCP tool is prefixed with its server name.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // The model calls both tools in the same turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the MCP-side call reaches the server (the native EchoTool runs
    // in-process); the server still sees the original, unprefixed name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1098
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how MCP tool names are disambiguated and truncated when they
    // collide with native tools or with each other, and when they approach or
    // exceed MAX_TOOL_NAME_LENGTH. The final assertion pins the exact set of
    // names presented to the model.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Near-limit names that also exist on server "zzz", forcing
            // disambiguation where a full "<server>_" prefix would not fit.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Over-limit name: expected to be truncated to the max length.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected naming outcomes (as pinned below):
    // - unambiguous names pass through unchanged;
    // - conflicting "echo" tools get a full "<server>_" prefix;
    // - conflicting near-limit "a…" names get a single-character server
    //   prefix ("y_"/"z_") since the full prefix would exceed the limit;
    // - the "b…" name (shared by two servers) and the over-limit "c…" name
    //   appear truncated/deduplicated — NOTE(review): "b…" appears only once
    //   here; confirm that dropping one of the two colliding "b…" tools is
    //   the intended behavior.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1264
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // E2E test: cancelling a turn while a tool (InfiniteTool) is still
    // running must close the event stream with StopReason::Cancelled, and the
    // thread must remain usable for subsequent sends.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the order the prompt requested.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track when the echo tool (but not the infinite one) completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1350
1351#[gpui::test]
1352async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1353    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1354    let fake_model = model.as_fake();
1355
1356    let events_1 = thread
1357        .update(cx, |thread, cx| {
1358            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1359        })
1360        .unwrap();
1361    cx.run_until_parked();
1362    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1363    cx.run_until_parked();
1364
1365    let events_2 = thread
1366        .update(cx, |thread, cx| {
1367            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1368        })
1369        .unwrap();
1370    cx.run_until_parked();
1371    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1372    fake_model
1373        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1374    fake_model.end_last_completion_stream();
1375
1376    let events_1 = events_1.collect::<Vec<_>>().await;
1377    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1378    let events_2 = events_2.collect::<Vec<_>>().await;
1379    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1380}
1381
1382#[gpui::test]
1383async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1384    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1385    let fake_model = model.as_fake();
1386
1387    let events_1 = thread
1388        .update(cx, |thread, cx| {
1389            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1390        })
1391        .unwrap();
1392    cx.run_until_parked();
1393    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1394    fake_model
1395        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1396    fake_model.end_last_completion_stream();
1397    let events_1 = events_1.collect::<Vec<_>>().await;
1398
1399    let events_2 = thread
1400        .update(cx, |thread, cx| {
1401            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1402        })
1403        .unwrap();
1404    cx.run_until_parked();
1405    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1406    fake_model
1407        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1408    fake_model.end_last_completion_stream();
1409    let events_2 = events_2.collect::<Vec<_>>().await;
1410
1411    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1412    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1413}
1414
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with StopReason::Refusal, the thread should drop
    // every message from the last user message onward, leaving the transcript
    // empty in this single-exchange case.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any model output, the transcript contains only the user message.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply so there is content to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The user message itself is removed too, leaving an empty transcript.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1463
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first user message should empty the transcript
    // and reset token usage, and the thread must accept new messages
    // afterwards with fresh usage accounting.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update so there is state to roll back.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens = input + output; max_tokens comes from the fake model.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1579
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread to the
    // exact state it had after the first exchange — both the transcript and
    // the last reported token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message, reply, and a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Closure capturing the expected post-first-exchange state; used both
    // before the second send and again after truncation to prove the rollback
    // is exact.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange: keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: the first exchange — and its token
    // usage — must be restored exactly.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1688
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // Verifies that a separate summarization model generates the thread title
    // after the first exchange (using only the first line of its output), and
    // that no title regeneration happens on subsequent messages.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Use a dedicated fake model for summarization so its completions can be
    // driven independently of the main model's.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The streamed text spans multiple chunks and contains a newline; only
    // the first line ("Hello world") becomes the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request...
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    // ...and the title is unchanged.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1734
/// Building a completion request while one tool call is still pending (here,
/// a tool awaiting user permission) must omit the pending tool use entirely,
/// while completed tool uses and their results are still included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call stays pending, blocked on user permission.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This one runs to completion and produces a tool result.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is the system prompt; compare everything after it. The
    // permission-gated tool use must not appear anywhere below.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1809
/// End-to-end exercise of `NativeAgentConnection`: thread creation, model
/// listing and selection, a prompt round-trip through the fake model,
/// cancellation, and session cleanup once the `AcpThread` is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model and verify the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1947
/// Streaming tool-use input must surface as an initial `ToolCall` event
/// followed by `ToolCallUpdate`s: completed raw input, in-progress status,
/// streamed content, and a final completed status carrying the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First event: the pending tool call created from the partial input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // Update 1: completed input replaces the partial raw_input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Update 2: the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Update 3: the tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // Update 4: completion status plus the final raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2054
/// A completion that succeeds on the first attempt (in Burn mode, which
/// enables retries) must finish without emitting any `Retry` events.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, collecting any retries along the way.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2098
/// A retryable stream error (server overloaded) in Burn mode must trigger a
/// single retry after the advertised delay, and the assistant turn must
/// resume where it left off.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with a retryable error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript records the interrupted turn and its resumption.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2163
/// When the stream errors after a tool call was received, the tool must still
/// run to completion and its result must be included in the retried request.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Deliver a complete tool call, then fail the stream with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the retried request must contain the tool use
    // AND its finished result (messages[0] is the system prompt).
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2239
/// After exhausting `MAX_RETRY_ATTEMPTS`, the thread must stop retrying and
/// surface the final error to the event stream.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt: the initial one plus each of the allowed retries.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // Exactly MAX_RETRY_ATTEMPTS retries, numbered 1..=MAX.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported once, as the original completion error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2294
2295/// Filters out the stop events for asserting against in tests
2296fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2297    result_events
2298        .into_iter()
2299        .filter_map(|event| match event.unwrap() {
2300            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2301            _ => None,
2302        })
2303        .collect()
2304}
2305
/// Bundle of handles returned by [`setup`] for driving a `Thread` in tests.
struct ThreadTest {
    /// Model the thread was configured with; call `.as_fake()` in fake-model tests.
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    /// Passed to `setup_context_server` by tests that register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
2313
/// Which language model a test runs against: a real provider-backed model
/// (requires network and authentication) or the in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
2318
2319impl TestModel {
2320    fn id(&self) -> LanguageModelId {
2321        match self {
2322            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2323            TestModel::Fake => unreachable!(),
2324        }
2325    }
2326}
2327
/// Builds a [`ThreadTest`] around the requested [`TestModel`].
///
/// Seeds a fake filesystem with a settings file whose "test-profile" default
/// profile enables every test tool, initializes settings (and, for real
/// models, the full client/model stack), then constructs the project and
/// `Thread`. The settings file is watched so tests can edit it afterwards.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Settings enabling all of the test tools under a custom default profile.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);

        match model {
            // The fake model needs no client or provider registration.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real models talk to production services over HTTP.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                client::init_settings(cx);
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep global settings in sync with later edits to the settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fakes are constructed directly; real models are
    // found in the registry and their provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2432
/// Routes `log` output through `env_logger`, but only when the test run
/// opted in by setting `RUST_LOG`. Runs once at binary startup via `ctor`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2440
2441fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2442    let fs = fs.clone();
2443    cx.spawn({
2444        async move |cx| {
2445            let mut new_settings_content_rx = settings::watch_config_file(
2446                cx.background_executor(),
2447                fs,
2448                paths::settings_file().clone(),
2449            );
2450
2451            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2452                cx.update(|cx| {
2453                    SettingsStore::update_global(cx, |settings, cx| {
2454                        settings.set_user_settings(&new_settings_content, cx)
2455                    })
2456                })
2457                .ok();
2458            }
2459        }
2460    })
2461    .detach();
2462}
2463
2464fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2465    completion
2466        .tools
2467        .iter()
2468        .map(|tool| tool.name.clone())
2469        .collect()
2470}
2471
/// Registers and starts a fake MCP context server named `name` that
/// advertises `tools`.
///
/// Returns a channel that yields, for each incoming tool call, the call's
/// parameters together with a oneshot sender the test must use to supply
/// that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store knows about it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport that answers the MCP handshake and list-tools requests
    // itself, and forwards tool calls to the test through the channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish its handshake before returning to the test.
    cx.run_until_parked();
    mcp_tool_calls_rx
}