1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17};
18use gpui::{
19 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
20};
21use indoc::indoc;
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
24 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
26 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
27};
28use pretty_assertions::assert_eq;
29use project::{
30 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
31};
32use prompt_store::ProjectContext;
33use reqwest_client::ReqwestClient;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37use settings::{Settings, SettingsStore};
38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
39use util::path;
40
41mod test_tools;
42use test_tools::*;
43
44#[gpui::test]
45async fn test_echo(cx: &mut TestAppContext) {
46 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
47 let fake_model = model.as_fake();
48
49 let events = thread
50 .update(cx, |thread, cx| {
51 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
52 })
53 .unwrap();
54 cx.run_until_parked();
55 fake_model.send_last_completion_stream_text_chunk("Hello");
56 fake_model
57 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
58 fake_model.end_last_completion_stream();
59
60 let events = events.collect().await;
61 thread.update(cx, |thread, _cx| {
62 assert_eq!(
63 thread.last_message().unwrap().to_markdown(),
64 indoc! {"
65 ## Assistant
66
67 Hello
68 "}
69 )
70 });
71 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
72}
73
74#[gpui::test]
75async fn test_thinking(cx: &mut TestAppContext) {
76 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
77 let fake_model = model.as_fake();
78
79 let events = thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 UserMessageId::new(),
83 [indoc! {"
84 Testing:
85
86 Generate a thinking step where you just think the word 'Think',
87 and have your final answer be 'Hello'
88 "}],
89 cx,
90 )
91 })
92 .unwrap();
93 cx.run_until_parked();
94 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
95 text: "Think".to_string(),
96 signature: None,
97 });
98 fake_model.send_last_completion_stream_text_chunk("Hello");
99 fake_model
100 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
101 fake_model.end_last_completion_stream();
102
103 let events = events.collect().await;
104 thread.update(cx, |thread, _cx| {
105 assert_eq!(
106 thread.last_message().unwrap().to_markdown(),
107 indoc! {"
108 ## Assistant
109
110 <think>Think</think>
111 Hello
112 "}
113 )
114 });
115 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
116}
117
118#[gpui::test]
119async fn test_system_prompt(cx: &mut TestAppContext) {
120 let ThreadTest {
121 model,
122 thread,
123 project_context,
124 ..
125 } = setup(cx, TestModel::Fake).await;
126 let fake_model = model.as_fake();
127
128 project_context.update(cx, |project_context, _cx| {
129 project_context.shell = "test-shell".into()
130 });
131 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
132 thread
133 .update(cx, |thread, cx| {
134 thread.send(UserMessageId::new(), ["abc"], cx)
135 })
136 .unwrap();
137 cx.run_until_parked();
138 let mut pending_completions = fake_model.pending_completions();
139 assert_eq!(
140 pending_completions.len(),
141 1,
142 "unexpected pending completions: {:?}",
143 pending_completions
144 );
145
146 let pending_completion = pending_completions.pop().unwrap();
147 assert_eq!(pending_completion.messages[0].role, Role::System);
148
149 let system_message = &pending_completion.messages[0];
150 let system_prompt = system_message.content[0].to_str().unwrap();
151 assert!(
152 system_prompt.contains("test-shell"),
153 "unexpected system message: {:?}",
154 system_message
155 );
156 assert!(
157 system_prompt.contains("## Fixing Diagnostics"),
158 "unexpected system message: {:?}",
159 system_message
160 );
161}
162
163#[gpui::test]
164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
165 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
166 let fake_model = model.as_fake();
167
168 thread
169 .update(cx, |thread, cx| {
170 thread.send(UserMessageId::new(), ["abc"], cx)
171 })
172 .unwrap();
173 cx.run_until_parked();
174 let mut pending_completions = fake_model.pending_completions();
175 assert_eq!(
176 pending_completions.len(),
177 1,
178 "unexpected pending completions: {:?}",
179 pending_completions
180 );
181
182 let pending_completion = pending_completions.pop().unwrap();
183 assert_eq!(pending_completion.messages[0].role, Role::System);
184
185 let system_message = &pending_completion.messages[0];
186 let system_prompt = system_message.content[0].to_str().unwrap();
187 assert!(
188 !system_prompt.contains("## Tool Use"),
189 "unexpected system message: {:?}",
190 system_message
191 );
192 assert!(
193 !system_prompt.contains("## Fixing Diagnostics"),
194 "unexpected system message: {:?}",
195 system_message
196 );
197}
198
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy visible in the request messages:
    // exactly the newest message in each request carries `cache: true`, and
    // every earlier message is re-sent with `cache: false`. This is checked
    // across plain user/assistant turns and a tool-use/tool-result turn.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool result is delivered back to the model in a follow-up request;
    // it is now the newest message and therefore the only cached one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
344
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
    // exercises tool calls that finish both before and after the model's
    // response stream ends.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // DelayTool's output contains "Ding"; the model was asked to echo it, so
    // the final assistant message should mention it somewhere.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
404
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
    // verifies that tool-call input streams incrementally, i.e. that at least
    // one ToolCall event is observed while the input is still incomplete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Trigger the word_list tool so we can watch its input stream in.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be arriving; a missing
                        // "g" key (presumably one of the later-streamed fields)
                        // indicates we caught it mid-stream.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
456
457#[gpui::test]
458async fn test_tool_authorization(cx: &mut TestAppContext) {
459 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
460 let fake_model = model.as_fake();
461
462 let mut events = thread
463 .update(cx, |thread, cx| {
464 thread.add_tool(ToolRequiringPermission);
465 thread.send(UserMessageId::new(), ["abc"], cx)
466 })
467 .unwrap();
468 cx.run_until_parked();
469 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
470 LanguageModelToolUse {
471 id: "tool_id_1".into(),
472 name: ToolRequiringPermission::name().into(),
473 raw_input: "{}".into(),
474 input: json!({}),
475 is_input_complete: true,
476 thought_signature: None,
477 },
478 ));
479 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
480 LanguageModelToolUse {
481 id: "tool_id_2".into(),
482 name: ToolRequiringPermission::name().into(),
483 raw_input: "{}".into(),
484 input: json!({}),
485 is_input_complete: true,
486 thought_signature: None,
487 },
488 ));
489 fake_model.end_last_completion_stream();
490 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
491 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
492
493 // Approve the first
494 tool_call_auth_1
495 .response
496 .send(tool_call_auth_1.options[1].id.clone())
497 .unwrap();
498 cx.run_until_parked();
499
500 // Reject the second
501 tool_call_auth_2
502 .response
503 .send(tool_call_auth_1.options[2].id.clone())
504 .unwrap();
505 cx.run_until_parked();
506
507 let completion = fake_model.pending_completions().pop().unwrap();
508 let message = completion.messages.last().unwrap();
509 assert_eq!(
510 message.content,
511 vec![
512 language_model::MessageContent::ToolResult(LanguageModelToolResult {
513 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
514 tool_name: ToolRequiringPermission::name().into(),
515 is_error: false,
516 content: "Allowed".into(),
517 output: Some("Allowed".into())
518 }),
519 language_model::MessageContent::ToolResult(LanguageModelToolResult {
520 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
521 tool_name: ToolRequiringPermission::name().into(),
522 is_error: true,
523 content: "Permission to run tool denied by user".into(),
524 output: Some("Permission to run tool denied by user".into())
525 })
526 ]
527 );
528
529 // Simulate yet another tool call.
530 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
531 LanguageModelToolUse {
532 id: "tool_id_3".into(),
533 name: ToolRequiringPermission::name().into(),
534 raw_input: "{}".into(),
535 input: json!({}),
536 is_input_complete: true,
537 thought_signature: None,
538 },
539 ));
540 fake_model.end_last_completion_stream();
541
542 // Respond by always allowing tools.
543 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
544 tool_call_auth_3
545 .response
546 .send(tool_call_auth_3.options[0].id.clone())
547 .unwrap();
548 cx.run_until_parked();
549 let completion = fake_model.pending_completions().pop().unwrap();
550 let message = completion.messages.last().unwrap();
551 assert_eq!(
552 message.content,
553 vec![language_model::MessageContent::ToolResult(
554 LanguageModelToolResult {
555 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
556 tool_name: ToolRequiringPermission::name().into(),
557 is_error: false,
558 content: "Allowed".into(),
559 output: Some("Allowed".into())
560 }
561 )]
562 );
563
564 // Simulate a final tool call, ensuring we don't trigger authorization.
565 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
566 LanguageModelToolUse {
567 id: "tool_id_4".into(),
568 name: ToolRequiringPermission::name().into(),
569 raw_input: "{}".into(),
570 input: json!({}),
571 is_input_complete: true,
572 thought_signature: None,
573 },
574 ));
575 fake_model.end_last_completion_stream();
576 cx.run_until_parked();
577 let completion = fake_model.pending_completions().pop().unwrap();
578 let message = completion.messages.last().unwrap();
579 assert_eq!(
580 message.content,
581 vec![language_model::MessageContent::ToolResult(
582 LanguageModelToolResult {
583 tool_use_id: "tool_id_4".into(),
584 tool_name: ToolRequiringPermission::name().into(),
585 is_error: false,
586 content: "Allowed".into(),
587 output: Some("Allowed".into())
588 }
589 )]
590 );
591}
592
593#[gpui::test]
594async fn test_tool_hallucination(cx: &mut TestAppContext) {
595 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
596 let fake_model = model.as_fake();
597
598 let mut events = thread
599 .update(cx, |thread, cx| {
600 thread.send(UserMessageId::new(), ["abc"], cx)
601 })
602 .unwrap();
603 cx.run_until_parked();
604 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
605 LanguageModelToolUse {
606 id: "tool_id_1".into(),
607 name: "nonexistent_tool".into(),
608 raw_input: "{}".into(),
609 input: json!({}),
610 is_input_complete: true,
611 thought_signature: None,
612 },
613 ));
614 fake_model.end_last_completion_stream();
615
616 let tool_call = expect_tool_call(&mut events).await;
617 assert_eq!(tool_call.title, "nonexistent_tool");
618 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
619 let update = expect_tool_call_update_fields(&mut events).await;
620 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
621}
622
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the provider reports ToolUseLimitReached, `resume` should replay
    // the conversation with an extra "Continue where you left off" user
    // message appended (which becomes the newly cached message).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool ran; its result is sent back in a follow-up request and is the
    // newest (cached) message. messages[0] is the system prompt.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming replays history plus the synthetic continuation message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
739
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After ToolUseLimitReached, sending a brand-new user message (rather
    // than resuming) should still replay the prior tool use/result history,
    // with the new message appended as the cached one.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately hit the limit in the same stream.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new "ghi" message is cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
818
819async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
820 let event = events
821 .next()
822 .await
823 .expect("no tool call authorization event received")
824 .unwrap();
825 match event {
826 ThreadEvent::ToolCall(tool_call) => tool_call,
827 event => {
828 panic!("Unexpected event {event:?}");
829 }
830 }
831}
832
833async fn expect_tool_call_update_fields(
834 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
835) -> acp::ToolCallUpdate {
836 let event = events
837 .next()
838 .await
839 .expect("no tool call authorization event received")
840 .unwrap();
841 match event {
842 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
843 event => {
844 panic!("Unexpected event {event:?}");
845 }
846 }
847}
848
849async fn next_tool_call_authorization(
850 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
851) -> ToolCallAuthorization {
852 loop {
853 let event = events
854 .next()
855 .await
856 .expect("no tool call authorization event received")
857 .unwrap();
858 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
859 let permission_kinds = tool_call_authorization
860 .options
861 .iter()
862 .map(|o| o.kind)
863 .collect::<Vec<_>>();
864 assert_eq!(
865 permission_kinds,
866 vec![
867 acp::PermissionOptionKind::AllowAlways,
868 acp::PermissionOptionKind::AllowOnce,
869 acp::PermissionOptionKind::RejectOnce,
870 ]
871 );
872 return tool_call_authorization;
873 }
874 }
875}
876
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (real Sonnet 4 model, only runs with the "e2e" feature):
    // two delay-tool calls issued in the same turn should both complete, and
    // the model should then describe their outputs.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // DelayTool's output contains "Ding"; check the concatenated text of the
    // final assistant message mentions it.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
921
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which registered tools
    // are offered to the model in the completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1001
1002#[gpui::test]
1003async fn test_mcp_tools(cx: &mut TestAppContext) {
1004 let ThreadTest {
1005 model,
1006 thread,
1007 context_server_store,
1008 fs,
1009 ..
1010 } = setup(cx, TestModel::Fake).await;
1011 let fake_model = model.as_fake();
1012
1013 // Override profiles and wait for settings to be loaded.
1014 fs.insert_file(
1015 paths::settings_file(),
1016 json!({
1017 "agent": {
1018 "always_allow_tool_actions": true,
1019 "profiles": {
1020 "test": {
1021 "name": "Test Profile",
1022 "enable_all_context_servers": true,
1023 "tools": {
1024 EchoTool::name(): true,
1025 }
1026 },
1027 }
1028 }
1029 })
1030 .to_string()
1031 .into_bytes(),
1032 )
1033 .await;
1034 cx.run_until_parked();
1035 thread.update(cx, |thread, cx| {
1036 thread.set_profile(AgentProfileId("test".into()), cx)
1037 });
1038
1039 let mut mcp_tool_calls = setup_context_server(
1040 "test_server",
1041 vec![context_server::types::Tool {
1042 name: "echo".into(),
1043 description: None,
1044 input_schema: serde_json::to_value(EchoTool::input_schema(
1045 LanguageModelToolSchemaFormat::JsonSchema,
1046 ))
1047 .unwrap(),
1048 output_schema: None,
1049 annotations: None,
1050 }],
1051 &context_server_store,
1052 cx,
1053 );
1054
1055 let events = thread.update(cx, |thread, cx| {
1056 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1057 });
1058 cx.run_until_parked();
1059
1060 // Simulate the model calling the MCP tool.
1061 let completion = fake_model.pending_completions().pop().unwrap();
1062 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1063 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1064 LanguageModelToolUse {
1065 id: "tool_1".into(),
1066 name: "echo".into(),
1067 raw_input: json!({"text": "test"}).to_string(),
1068 input: json!({"text": "test"}),
1069 is_input_complete: true,
1070 thought_signature: None,
1071 },
1072 ));
1073 fake_model.end_last_completion_stream();
1074 cx.run_until_parked();
1075
1076 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1077 assert_eq!(tool_call_params.name, "echo");
1078 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1079 tool_call_response
1080 .send(context_server::types::CallToolResponse {
1081 content: vec![context_server::types::ToolResponseContent::Text {
1082 text: "test".into(),
1083 }],
1084 is_error: None,
1085 meta: None,
1086 structured_content: None,
1087 })
1088 .unwrap();
1089 cx.run_until_parked();
1090
1091 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1092 fake_model.send_last_completion_stream_text_chunk("Done!");
1093 fake_model.end_last_completion_stream();
1094 events.collect::<Vec<_>>().await;
1095
1096 // Send again after adding the echo tool, ensuring the name collision is resolved.
1097 let events = thread.update(cx, |thread, cx| {
1098 thread.add_tool(EchoTool);
1099 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1100 });
1101 cx.run_until_parked();
1102 let completion = fake_model.pending_completions().pop().unwrap();
1103 assert_eq!(
1104 tool_names_for_completion(&completion),
1105 vec!["echo", "test_server_echo"]
1106 );
1107 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1108 LanguageModelToolUse {
1109 id: "tool_2".into(),
1110 name: "test_server_echo".into(),
1111 raw_input: json!({"text": "mcp"}).to_string(),
1112 input: json!({"text": "mcp"}),
1113 is_input_complete: true,
1114 thought_signature: None,
1115 },
1116 ));
1117 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1118 LanguageModelToolUse {
1119 id: "tool_3".into(),
1120 name: "echo".into(),
1121 raw_input: json!({"text": "native"}).to_string(),
1122 input: json!({"text": "native"}),
1123 is_input_complete: true,
1124 thought_signature: None,
1125 },
1126 ));
1127 fake_model.end_last_completion_stream();
1128 cx.run_until_parked();
1129
1130 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1131 assert_eq!(tool_call_params.name, "echo");
1132 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1133 tool_call_response
1134 .send(context_server::types::CallToolResponse {
1135 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1136 is_error: None,
1137 meta: None,
1138 structured_content: None,
1139 })
1140 .unwrap();
1141 cx.run_until_parked();
1142
1143 // Ensure the tool results were inserted with the correct names.
1144 let completion = fake_model.pending_completions().pop().unwrap();
1145 assert_eq!(
1146 completion.messages.last().unwrap().content,
1147 vec![
1148 MessageContent::ToolResult(LanguageModelToolResult {
1149 tool_use_id: "tool_3".into(),
1150 tool_name: "echo".into(),
1151 is_error: false,
1152 content: "native".into(),
1153 output: Some("native".into()),
1154 },),
1155 MessageContent::ToolResult(LanguageModelToolResult {
1156 tool_use_id: "tool_2".into(),
1157 tool_name: "test_server_echo".into(),
1158 is_error: false,
1159 content: "mcp".into(),
1160 output: Some("mcp".into()),
1161 },),
1162 ]
1163 );
1164 fake_model.end_last_completion_stream();
1165 events.collect::<Vec<_>>().await;
1166}
1167
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Exercises tool-name disambiguation when several MCP context servers
    // expose tools whose names collide with native tools, collide with each
    // other, or exceed MAX_TOOL_NAME_LENGTH.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register all native tools on the thread.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names: a full "yyy_" prefix would overflow the limit, so
            // the expected list below shows a shortened "y_" prefix instead.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            // Same long names as server "yyy" to force cross-server collisions.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One character over the limit: expected to be truncated to fit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Colliding "echo" names get a server-id prefix; over-long names are
    // truncated or get a single-letter server prefix.
    // NOTE(review): only one "bbb…" entry appears even though both "yyy" and
    // "zzz" expose it — presumably names that cannot be disambiguated within
    // the length limit collapse to one; confirm against the truncation logic.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1333
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (runs against the real Sonnet 4 model, hence ignored
    // unless the "e2e" feature is enabled): cancelling a turn must close the
    // event stream with StopReason::Cancelled even while a tool is still
    // running, and the thread must remain usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the order the prompt requested them.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // NOTE(review): this arm only matches updates whose `meta` is
            // `None` — presumably completion updates carry no meta here;
            // confirm if `meta` is ever populated for status updates.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools have been called and the echo tool finished;
        // the infinite tool is, by design, still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1419
1420#[gpui::test]
1421async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1422 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1423 let fake_model = model.as_fake();
1424
1425 let events_1 = thread
1426 .update(cx, |thread, cx| {
1427 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1428 })
1429 .unwrap();
1430 cx.run_until_parked();
1431 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1432 cx.run_until_parked();
1433
1434 let events_2 = thread
1435 .update(cx, |thread, cx| {
1436 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1437 })
1438 .unwrap();
1439 cx.run_until_parked();
1440 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1441 fake_model
1442 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1443 fake_model.end_last_completion_stream();
1444
1445 let events_1 = events_1.collect::<Vec<_>>().await;
1446 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1447 let events_2 = events_2.collect::<Vec<_>>().await;
1448 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1449}
1450
1451#[gpui::test]
1452async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1453 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1454 let fake_model = model.as_fake();
1455
1456 let events_1 = thread
1457 .update(cx, |thread, cx| {
1458 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1459 })
1460 .unwrap();
1461 cx.run_until_parked();
1462 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1463 fake_model
1464 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1465 fake_model.end_last_completion_stream();
1466 let events_1 = events_1.collect::<Vec<_>>().await;
1467
1468 let events_2 = thread
1469 .update(cx, |thread, cx| {
1470 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1471 })
1472 .unwrap();
1473 cx.run_until_parked();
1474 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1475 fake_model
1476 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1477 fake_model.end_last_completion_stream();
1478 let events_2 = events_2.collect::<Vec<_>>().await;
1479
1480 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1481 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1482}
1483
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with StopReason::Refusal, the turn must surface
    // acp::StopReason::Refusal and the thread must discard the refused
    // exchange (the transcript ends up completely empty below).
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded before any model output arrives.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // Partial assistant output is visible while the stream is still open.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // NOTE(review): the transcript is fully empty here, so the triggering user
    // message is dropped as well, not just the assistant's partial reply.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1532
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first message should clear the whole transcript
    // and reset token accounting, while leaving the thread usable for new
    // messages afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported by the model yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the UsageUpdate event (input + output tokens) against
        // the fake model's 1M-token context window.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message removes it and everything after it,
    // and resets the reported token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage after truncation reflects only the new exchange.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1648
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second message must restore the thread to exactly the
    // state it had after the first exchange, including the earlier token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Captures the expected post-first-exchange state; reused after the
    // truncation below to verify the thread rolled back to it exactly.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Run a complete second exchange with different (larger) token usage.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message: both it and its response are removed,
    // and token usage reverts to the first exchange's values.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1757
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A dedicated summarization model generates the thread title from the
    // first exchange; subsequent turns must not regenerate it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model used only for title summarization.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the placeholder title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Only the first line of the streamed summary becomes the title.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new summarization request should have been issued.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1803
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When a completion request is built while a tool call is still awaiting
    // user permission, that pending tool use (and its missing result) must be
    // omitted from the request; completed tool uses are included.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool will block awaiting permission, so it stays pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // The echo tool runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // NOTE(review): messages[0] is skipped by this assertion — presumably the
    // system/context message; confirm against build_completion_request.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // The assistant message contains only the completed echo tool use;
            // the permission-gated tool use is absent.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
1883
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Integration test of the ACP connection surface: thread creation, model
    // listing/selection, prompting, cancellation, and session teardown.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and verify the streamed reply lands in
    // the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2018
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the full sequence of ACP tool-call events emitted for a single
    // tool invocation: initial ToolCall, then updates for streamed input,
    // in-progress status, content, and final completion with raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) The initial tool call, still pending, with the partial (empty) input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // 2) Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 3) Status flips to InProgress as the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // 5) Completion, including the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2127
2128#[gpui::test]
2129async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2130 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2131 let fake_model = model.as_fake();
2132
2133 let mut events = thread
2134 .update(cx, |thread, cx| {
2135 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2136 thread.send(UserMessageId::new(), ["Hello!"], cx)
2137 })
2138 .unwrap();
2139 cx.run_until_parked();
2140
2141 fake_model.send_last_completion_stream_text_chunk("Hey!");
2142 fake_model.end_last_completion_stream();
2143
2144 let mut retry_events = Vec::new();
2145 while let Some(Ok(event)) = events.next().await {
2146 match event {
2147 ThreadEvent::Retry(retry_status) => {
2148 retry_events.push(retry_status);
2149 }
2150 ThreadEvent::Stop(..) => break,
2151 _ => {}
2152 }
2153 }
2154
2155 assert_eq!(retry_events.len(), 0);
2156 thread.read_with(cx, |thread, _cx| {
2157 assert_eq!(
2158 thread.to_markdown(),
2159 indoc! {"
2160 ## User
2161
2162 Hello!
2163
2164 ## Assistant
2165
2166 Hey!
2167 "}
2168 )
2169 });
2170}
2171
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // In Burn mode, a ServerOverloaded error mid-stream should trigger an
    // automatic retry after the provider's retry_after delay, resuming the
    // same turn and surfacing a single Retry event.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial reply, then fail with an overloaded error that asks
    // for a 3-second backoff.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the backoff so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the reply.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript records the interrupted reply, a [resume] marker, and
    // the continuation from the retried completion.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2236
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a retryable error arrives after a complete tool use was streamed,
    // the tool call should still run, and the retried request must include
    // both the tool use and its result so the model can finish the turn.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit a fully-formed tool use, then fail the stream with a retryable
    // error before the turn can complete.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the backoff, the retried request (skipping the system message at
    // index 0) must carry the user message, the assistant's tool use, and a
    // successful tool result.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Let the retried completion finish the turn.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
2317
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, the thread
    // must stop retrying and surface the final error to the caller.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail one more time than the retry budget allows, advancing the clock
    // past each 3-second backoff so every retry attempt is actually made.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One retry event per attempt, numbered starting at 1.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The stream ends with the underlying provider error, not a Stop event.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2372
2373/// Filters out the stop events for asserting against in tests
2374fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2375 result_events
2376 .into_iter()
2377 .filter_map(|event| match event.unwrap() {
2378 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2379 _ => None,
2380 })
2381 .collect()
2382}
2383
/// Fixture bundle returned by [`setup`] that tests use to drive a `Thread`.
struct ThreadTest {
    // The active language model; a `FakeLanguageModel` for `TestModel::Fake`.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store used to register (fake) MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
2391
/// Which language model a test runs against.
enum TestModel {
    // A real Anthropic model looked up in the registry; requires network and
    // provider authentication (see `setup`).
    Sonnet4,
    // The in-process `FakeLanguageModel`, fully scripted by the test.
    Fake,
}
2396
2397impl TestModel {
2398 fn id(&self) -> LanguageModelId {
2399 match self {
2400 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2401 TestModel::Fake => unreachable!(),
2402 }
2403 }
2404}
2405
/// Builds the shared test fixture: a fake filesystem seeded with a settings
/// file that enables all test tools, a test project rooted at `/test`, and a
/// `Thread` backed by either the fake model or a real, authenticated one.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Allow this thread to block (park); real-model runs do network work.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose "test-profile" enables every test tool, so
    // tools registered on the thread are permitted by the active profile.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real models need an HTTP client, a production client, and
                // the language-model providers registered globally.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Mirror changes to the settings file above into the global store.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate
                // its provider before handing the model to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2507
/// Test-binary constructor: enables `env_logger` output when `RUST_LOG` is
/// set, so tests can be rerun with logging without code changes.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2515
/// Spawns a detached task that mirrors changes to the settings file on `fs`
/// into the global `SettingsStore`, so tests can change settings by writing
/// to the fake filesystem.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            // Re-parse and apply the user settings on every file change.
            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                // The app context may already be gone during teardown; ignore.
                .ok();
            }
        }
    })
    .detach();
}
2538
2539fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2540 completion
2541 .tools
2542 .iter()
2543 .map(|tool| tool.name.clone())
2544 .collect()
2545}
2546
/// Registers a fake MCP context server named `name` that advertises `tools`,
/// and returns a receiver of incoming tool calls.
///
/// Each received item pairs the `CallToolParams` with a oneshot sender the
/// test must fire to deliver that tool call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will run it; the
    // command path is a dummy since the transport is faked below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support with list_changed notifications.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Report the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and await its scripted response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish starting before tests interact with it.
    cx.run_until_parked();
    mcp_tool_calls_rx
}