// mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use fs::{FakeFs, Fs};
  11use futures::{
  12    StreamExt,
  13    channel::{
  14        mpsc::{self, UnboundedReceiver},
  15        oneshot,
  16    },
  17};
  18use gpui::{
  19    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  20};
  21use indoc::indoc;
  22use language_model::{
  23    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  24    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  25    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  26    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  27};
  28use pretty_assertions::assert_eq;
  29use project::{
  30    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  31};
  32use prompt_store::ProjectContext;
  33use reqwest_client::ReqwestClient;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use serde_json::json;
  37use settings::{Settings, SettingsStore};
  38use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  39use util::path;
  40
  41mod test_tools;
  42use test_tools::*;
  43
  44#[gpui::test]
  45async fn test_echo(cx: &mut TestAppContext) {
  46    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  47    let fake_model = model.as_fake();
  48
  49    let events = thread
  50        .update(cx, |thread, cx| {
  51            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  52        })
  53        .unwrap();
  54    cx.run_until_parked();
  55    fake_model.send_last_completion_stream_text_chunk("Hello");
  56    fake_model
  57        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  58    fake_model.end_last_completion_stream();
  59
  60    let events = events.collect().await;
  61    thread.update(cx, |thread, _cx| {
  62        assert_eq!(
  63            thread.last_message().unwrap().to_markdown(),
  64            indoc! {"
  65                ## Assistant
  66
  67                Hello
  68            "}
  69        )
  70    });
  71    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  72}
  73
  74#[gpui::test]
  75async fn test_thinking(cx: &mut TestAppContext) {
  76    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  77    let fake_model = model.as_fake();
  78
  79    let events = thread
  80        .update(cx, |thread, cx| {
  81            thread.send(
  82                UserMessageId::new(),
  83                [indoc! {"
  84                    Testing:
  85
  86                    Generate a thinking step where you just think the word 'Think',
  87                    and have your final answer be 'Hello'
  88                "}],
  89                cx,
  90            )
  91        })
  92        .unwrap();
  93    cx.run_until_parked();
  94    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
  95        text: "Think".to_string(),
  96        signature: None,
  97    });
  98    fake_model.send_last_completion_stream_text_chunk("Hello");
  99    fake_model
 100        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 101    fake_model.end_last_completion_stream();
 102
 103    let events = events.collect().await;
 104    thread.update(cx, |thread, _cx| {
 105        assert_eq!(
 106            thread.last_message().unwrap().to_markdown(),
 107            indoc! {"
 108                ## Assistant
 109
 110                <think>Think</think>
 111                Hello
 112            "}
 113        )
 114    });
 115    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 116}
 117
 118#[gpui::test]
 119async fn test_system_prompt(cx: &mut TestAppContext) {
 120    let ThreadTest {
 121        model,
 122        thread,
 123        project_context,
 124        ..
 125    } = setup(cx, TestModel::Fake).await;
 126    let fake_model = model.as_fake();
 127
 128    project_context.update(cx, |project_context, _cx| {
 129        project_context.shell = "test-shell".into()
 130    });
 131    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 132    thread
 133        .update(cx, |thread, cx| {
 134            thread.send(UserMessageId::new(), ["abc"], cx)
 135        })
 136        .unwrap();
 137    cx.run_until_parked();
 138    let mut pending_completions = fake_model.pending_completions();
 139    assert_eq!(
 140        pending_completions.len(),
 141        1,
 142        "unexpected pending completions: {:?}",
 143        pending_completions
 144    );
 145
 146    let pending_completion = pending_completions.pop().unwrap();
 147    assert_eq!(pending_completion.messages[0].role, Role::System);
 148
 149    let system_message = &pending_completion.messages[0];
 150    let system_prompt = system_message.content[0].to_str().unwrap();
 151    assert!(
 152        system_prompt.contains("test-shell"),
 153        "unexpected system message: {:?}",
 154        system_message
 155    );
 156    assert!(
 157        system_prompt.contains("## Fixing Diagnostics"),
 158        "unexpected system message: {:?}",
 159        system_message
 160    );
 161}
 162
 163#[gpui::test]
 164async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
 165    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 166    let fake_model = model.as_fake();
 167
 168    thread
 169        .update(cx, |thread, cx| {
 170            thread.send(UserMessageId::new(), ["abc"], cx)
 171        })
 172        .unwrap();
 173    cx.run_until_parked();
 174    let mut pending_completions = fake_model.pending_completions();
 175    assert_eq!(
 176        pending_completions.len(),
 177        1,
 178        "unexpected pending completions: {:?}",
 179        pending_completions
 180    );
 181
 182    let pending_completion = pending_completions.pop().unwrap();
 183    assert_eq!(pending_completion.messages[0].role, Role::System);
 184
 185    let system_message = &pending_completion.messages[0];
 186    let system_prompt = system_message.content[0].to_str().unwrap();
 187    assert!(
 188        !system_prompt.contains("## Tool Use"),
 189        "unexpected system message: {:?}",
 190        system_message
 191    );
 192    assert!(
 193        !system_prompt.contains("## Fixing Diagnostics"),
 194        "unexpected system message: {:?}",
 195        system_message
 196    );
 197}
 198
#[gpui::test]
// Verifies the `cache` flag on outgoing request messages: only the most
// recent message (whether a user message or a tool result) is marked
// cacheable; everything earlier in the history is sent with `cache: false`.
// NOTE(review): `completion.messages[1..]` skips the system message, which
// is checked separately in `test_system_prompt`.
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up completion should carry the whole
    // history uncached except the trailing tool result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 332
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
// End-to-end test against a real model (runs only with the "e2e" feature):
// exercises tool calls that resolve both before and after the model's
// streaming response finishes.
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should have been echoed back in
        // the assistant's final message.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 392
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
// End-to-end test against a real model (runs only with the "e2e" feature):
// verifies that tool-use input streams incrementally — we should observe at
// least one Pending tool call whose input is still incomplete before the
// final, complete input arrives.
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // "g" (presumably the last field of the word list
                        // schema) hasn't arrived yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past Pending, the full input must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 444
 445#[gpui::test]
 446async fn test_tool_authorization(cx: &mut TestAppContext) {
 447    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 448    let fake_model = model.as_fake();
 449
 450    let mut events = thread
 451        .update(cx, |thread, cx| {
 452            thread.add_tool(ToolRequiringPermission);
 453            thread.send(UserMessageId::new(), ["abc"], cx)
 454        })
 455        .unwrap();
 456    cx.run_until_parked();
 457    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 458        LanguageModelToolUse {
 459            id: "tool_id_1".into(),
 460            name: ToolRequiringPermission::name().into(),
 461            raw_input: "{}".into(),
 462            input: json!({}),
 463            is_input_complete: true,
 464        },
 465    ));
 466    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 467        LanguageModelToolUse {
 468            id: "tool_id_2".into(),
 469            name: ToolRequiringPermission::name().into(),
 470            raw_input: "{}".into(),
 471            input: json!({}),
 472            is_input_complete: true,
 473        },
 474    ));
 475    fake_model.end_last_completion_stream();
 476    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 477    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 478
 479    // Approve the first
 480    tool_call_auth_1
 481        .response
 482        .send(tool_call_auth_1.options[1].id.clone())
 483        .unwrap();
 484    cx.run_until_parked();
 485
 486    // Reject the second
 487    tool_call_auth_2
 488        .response
 489        .send(tool_call_auth_1.options[2].id.clone())
 490        .unwrap();
 491    cx.run_until_parked();
 492
 493    let completion = fake_model.pending_completions().pop().unwrap();
 494    let message = completion.messages.last().unwrap();
 495    assert_eq!(
 496        message.content,
 497        vec![
 498            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 499                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 500                tool_name: ToolRequiringPermission::name().into(),
 501                is_error: false,
 502                content: "Allowed".into(),
 503                output: Some("Allowed".into())
 504            }),
 505            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 506                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 507                tool_name: ToolRequiringPermission::name().into(),
 508                is_error: true,
 509                content: "Permission to run tool denied by user".into(),
 510                output: Some("Permission to run tool denied by user".into())
 511            })
 512        ]
 513    );
 514
 515    // Simulate yet another tool call.
 516    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 517        LanguageModelToolUse {
 518            id: "tool_id_3".into(),
 519            name: ToolRequiringPermission::name().into(),
 520            raw_input: "{}".into(),
 521            input: json!({}),
 522            is_input_complete: true,
 523        },
 524    ));
 525    fake_model.end_last_completion_stream();
 526
 527    // Respond by always allowing tools.
 528    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 529    tool_call_auth_3
 530        .response
 531        .send(tool_call_auth_3.options[0].id.clone())
 532        .unwrap();
 533    cx.run_until_parked();
 534    let completion = fake_model.pending_completions().pop().unwrap();
 535    let message = completion.messages.last().unwrap();
 536    assert_eq!(
 537        message.content,
 538        vec![language_model::MessageContent::ToolResult(
 539            LanguageModelToolResult {
 540                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 541                tool_name: ToolRequiringPermission::name().into(),
 542                is_error: false,
 543                content: "Allowed".into(),
 544                output: Some("Allowed".into())
 545            }
 546        )]
 547    );
 548
 549    // Simulate a final tool call, ensuring we don't trigger authorization.
 550    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 551        LanguageModelToolUse {
 552            id: "tool_id_4".into(),
 553            name: ToolRequiringPermission::name().into(),
 554            raw_input: "{}".into(),
 555            input: json!({}),
 556            is_input_complete: true,
 557        },
 558    ));
 559    fake_model.end_last_completion_stream();
 560    cx.run_until_parked();
 561    let completion = fake_model.pending_completions().pop().unwrap();
 562    let message = completion.messages.last().unwrap();
 563    assert_eq!(
 564        message.content,
 565        vec![language_model::MessageContent::ToolResult(
 566            LanguageModelToolResult {
 567                tool_use_id: "tool_id_4".into(),
 568                tool_name: ToolRequiringPermission::name().into(),
 569                is_error: false,
 570                content: "Allowed".into(),
 571                output: Some("Allowed".into())
 572            }
 573        )]
 574    );
 575}
 576
 577#[gpui::test]
 578async fn test_tool_hallucination(cx: &mut TestAppContext) {
 579    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 580    let fake_model = model.as_fake();
 581
 582    let mut events = thread
 583        .update(cx, |thread, cx| {
 584            thread.send(UserMessageId::new(), ["abc"], cx)
 585        })
 586        .unwrap();
 587    cx.run_until_parked();
 588    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 589        LanguageModelToolUse {
 590            id: "tool_id_1".into(),
 591            name: "nonexistent_tool".into(),
 592            raw_input: "{}".into(),
 593            input: json!({}),
 594            is_input_complete: true,
 595        },
 596    ));
 597    fake_model.end_last_completion_stream();
 598
 599    let tool_call = expect_tool_call(&mut events).await;
 600    assert_eq!(tool_call.title, "nonexistent_tool");
 601    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 602    let update = expect_tool_call_update_fields(&mut events).await;
 603    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 604}
 605
#[gpui::test]
// After the provider reports ToolUseLimitReached, `resume` should retry the
// turn with a synthetic "Continue where you left off" user message appended,
// while keeping the interrupted tool use and result in the history.
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The follow-up completion carries the tool use and its result; only the
    // trailing tool result is marked cacheable.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn's final event should be a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming injects a "Continue where you left off" user message (now the
    // cacheable tail of the request).
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
 714
#[gpui::test]
// After hitting the tool use limit mid-turn, sending a brand-new user
// message should still work: the interrupted tool use and its result remain
// in the request history, and only the new message is marked cacheable.
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    // Hit the limit in the same turn as the tool use.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 788
 789async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 790    let event = events
 791        .next()
 792        .await
 793        .expect("no tool call authorization event received")
 794        .unwrap();
 795    match event {
 796        ThreadEvent::ToolCall(tool_call) => tool_call,
 797        event => {
 798            panic!("Unexpected event {event:?}");
 799        }
 800    }
 801}
 802
 803async fn expect_tool_call_update_fields(
 804    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 805) -> acp::ToolCallUpdate {
 806    let event = events
 807        .next()
 808        .await
 809        .expect("no tool call authorization event received")
 810        .unwrap();
 811    match event {
 812        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 813        event => {
 814            panic!("Unexpected event {event:?}");
 815        }
 816    }
 817}
 818
 819async fn next_tool_call_authorization(
 820    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 821) -> ToolCallAuthorization {
 822    loop {
 823        let event = events
 824            .next()
 825            .await
 826            .expect("no tool call authorization event received")
 827            .unwrap();
 828        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 829            let permission_kinds = tool_call_authorization
 830                .options
 831                .iter()
 832                .map(|o| o.kind)
 833                .collect::<Vec<_>>();
 834            assert_eq!(
 835                permission_kinds,
 836                vec![
 837                    acp::PermissionOptionKind::AllowAlways,
 838                    acp::PermissionOptionKind::AllowOnce,
 839                    acp::PermissionOptionKind::RejectOnce,
 840                ]
 841            );
 842            return tool_call_authorization;
 843        }
 844    }
 845}
 846
 847#[gpui::test]
 848#[cfg_attr(not(feature = "e2e"), ignore)]
 849async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 850    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 851
 852    // Test concurrent tool calls with different delay times
 853    let events = thread
 854        .update(cx, |thread, cx| {
 855            thread.add_tool(DelayTool);
 856            thread.send(
 857                UserMessageId::new(),
 858                [
 859                    "Call the delay tool twice in the same message.",
 860                    "Once with 100ms. Once with 300ms.",
 861                    "When both timers are complete, describe the outputs.",
 862                ],
 863                cx,
 864            )
 865        })
 866        .unwrap()
 867        .collect()
 868        .await;
 869
 870    let stop_reasons = stop_events(events);
 871    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 872
 873    thread.update(cx, |thread, _cx| {
 874        let last_message = thread.last_message().unwrap();
 875        let agent_message = last_message.as_agent_message().unwrap();
 876        let text = agent_message
 877            .content
 878            .iter()
 879            .filter_map(|content| {
 880                if let AgentMessageContent::Text(text) = content {
 881                    Some(text.as_str())
 882                } else {
 883                    None
 884                }
 885            })
 886            .collect::<String>();
 887
 888        assert!(text.contains("Ding"));
 889    });
 890}
 891
 892#[gpui::test]
 893async fn test_profiles(cx: &mut TestAppContext) {
 894    let ThreadTest {
 895        model, thread, fs, ..
 896    } = setup(cx, TestModel::Fake).await;
 897    let fake_model = model.as_fake();
 898
 899    thread.update(cx, |thread, _cx| {
 900        thread.add_tool(DelayTool);
 901        thread.add_tool(EchoTool);
 902        thread.add_tool(InfiniteTool);
 903    });
 904
 905    // Override profiles and wait for settings to be loaded.
 906    fs.insert_file(
 907        paths::settings_file(),
 908        json!({
 909            "agent": {
 910                "profiles": {
 911                    "test-1": {
 912                        "name": "Test Profile 1",
 913                        "tools": {
 914                            EchoTool::name(): true,
 915                            DelayTool::name(): true,
 916                        }
 917                    },
 918                    "test-2": {
 919                        "name": "Test Profile 2",
 920                        "tools": {
 921                            InfiniteTool::name(): true,
 922                        }
 923                    }
 924                }
 925            }
 926        })
 927        .to_string()
 928        .into_bytes(),
 929    )
 930    .await;
 931    cx.run_until_parked();
 932
 933    // Test that test-1 profile (default) has echo and delay tools
 934    thread
 935        .update(cx, |thread, cx| {
 936            thread.set_profile(AgentProfileId("test-1".into()), cx);
 937            thread.send(UserMessageId::new(), ["test"], cx)
 938        })
 939        .unwrap();
 940    cx.run_until_parked();
 941
 942    let mut pending_completions = fake_model.pending_completions();
 943    assert_eq!(pending_completions.len(), 1);
 944    let completion = pending_completions.pop().unwrap();
 945    let tool_names: Vec<String> = completion
 946        .tools
 947        .iter()
 948        .map(|tool| tool.name.clone())
 949        .collect();
 950    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
 951    fake_model.end_last_completion_stream();
 952
 953    // Switch to test-2 profile, and verify that it has only the infinite tool.
 954    thread
 955        .update(cx, |thread, cx| {
 956            thread.set_profile(AgentProfileId("test-2".into()), cx);
 957            thread.send(UserMessageId::new(), ["test2"], cx)
 958        })
 959        .unwrap();
 960    cx.run_until_parked();
 961    let mut pending_completions = fake_model.pending_completions();
 962    assert_eq!(pending_completions.len(), 1);
 963    let completion = pending_completions.pop().unwrap();
 964    let tool_names: Vec<String> = completion
 965        .tools
 966        .iter()
 967        .map(|tool| tool.name.clone())
 968        .collect();
 969    assert_eq!(tool_names, vec![InfiniteTool::name()]);
 970}
 971
// Verifies that MCP (context server) tools are exposed to the model, that a
// name collision with a native tool is resolved by prefixing the server name,
// and that tool results round-trip with the correct names.
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a fake context server exposing a single tool named "echo";
    // `mcp_tool_calls` yields (params, response-channel) pairs for each call.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The fake context server should receive the call; reply through the
    // response channel to unblock the thread.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts on the `completion` captured before the
    // tool result was sent, not on a freshly popped pending completion —
    // presumably intended to inspect the follow-up request; confirm.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The MCP tool is renamed to "test_server_echo" to avoid colliding with
    // the native "echo" tool.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the prefixed call is routed to the MCP server; note the server
    // still sees its original tool name "echo".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1134
1135#[gpui::test]
1136async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1137    let ThreadTest {
1138        model,
1139        thread,
1140        context_server_store,
1141        fs,
1142        ..
1143    } = setup(cx, TestModel::Fake).await;
1144    let fake_model = model.as_fake();
1145
1146    // Set up a profile with all tools enabled
1147    fs.insert_file(
1148        paths::settings_file(),
1149        json!({
1150            "agent": {
1151                "profiles": {
1152                    "test": {
1153                        "name": "Test Profile",
1154                        "enable_all_context_servers": true,
1155                        "tools": {
1156                            EchoTool::name(): true,
1157                            DelayTool::name(): true,
1158                            WordListTool::name(): true,
1159                            ToolRequiringPermission::name(): true,
1160                            InfiniteTool::name(): true,
1161                        }
1162                    },
1163                }
1164            }
1165        })
1166        .to_string()
1167        .into_bytes(),
1168    )
1169    .await;
1170    cx.run_until_parked();
1171
1172    thread.update(cx, |thread, cx| {
1173        thread.set_profile(AgentProfileId("test".into()), cx);
1174        thread.add_tool(EchoTool);
1175        thread.add_tool(DelayTool);
1176        thread.add_tool(WordListTool);
1177        thread.add_tool(ToolRequiringPermission);
1178        thread.add_tool(InfiniteTool);
1179    });
1180
1181    // Set up multiple context servers with some overlapping tool names
1182    let _server1_calls = setup_context_server(
1183        "xxx",
1184        vec![
1185            context_server::types::Tool {
1186                name: "echo".into(), // Conflicts with native EchoTool
1187                description: None,
1188                input_schema: serde_json::to_value(EchoTool::input_schema(
1189                    LanguageModelToolSchemaFormat::JsonSchema,
1190                ))
1191                .unwrap(),
1192                output_schema: None,
1193                annotations: None,
1194            },
1195            context_server::types::Tool {
1196                name: "unique_tool_1".into(),
1197                description: None,
1198                input_schema: json!({"type": "object", "properties": {}}),
1199                output_schema: None,
1200                annotations: None,
1201            },
1202        ],
1203        &context_server_store,
1204        cx,
1205    );
1206
1207    let _server2_calls = setup_context_server(
1208        "yyy",
1209        vec![
1210            context_server::types::Tool {
1211                name: "echo".into(), // Also conflicts with native EchoTool
1212                description: None,
1213                input_schema: serde_json::to_value(EchoTool::input_schema(
1214                    LanguageModelToolSchemaFormat::JsonSchema,
1215                ))
1216                .unwrap(),
1217                output_schema: None,
1218                annotations: None,
1219            },
1220            context_server::types::Tool {
1221                name: "unique_tool_2".into(),
1222                description: None,
1223                input_schema: json!({"type": "object", "properties": {}}),
1224                output_schema: None,
1225                annotations: None,
1226            },
1227            context_server::types::Tool {
1228                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1229                description: None,
1230                input_schema: json!({"type": "object", "properties": {}}),
1231                output_schema: None,
1232                annotations: None,
1233            },
1234            context_server::types::Tool {
1235                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1236                description: None,
1237                input_schema: json!({"type": "object", "properties": {}}),
1238                output_schema: None,
1239                annotations: None,
1240            },
1241        ],
1242        &context_server_store,
1243        cx,
1244    );
1245    let _server3_calls = setup_context_server(
1246        "zzz",
1247        vec![
1248            context_server::types::Tool {
1249                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1250                description: None,
1251                input_schema: json!({"type": "object", "properties": {}}),
1252                output_schema: None,
1253                annotations: None,
1254            },
1255            context_server::types::Tool {
1256                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1257                description: None,
1258                input_schema: json!({"type": "object", "properties": {}}),
1259                output_schema: None,
1260                annotations: None,
1261            },
1262            context_server::types::Tool {
1263                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1264                description: None,
1265                input_schema: json!({"type": "object", "properties": {}}),
1266                output_schema: None,
1267                annotations: None,
1268            },
1269        ],
1270        &context_server_store,
1271        cx,
1272    );
1273
1274    thread
1275        .update(cx, |thread, cx| {
1276            thread.send(UserMessageId::new(), ["Go"], cx)
1277        })
1278        .unwrap();
1279    cx.run_until_parked();
1280    let completion = fake_model.pending_completions().pop().unwrap();
1281    assert_eq!(
1282        tool_names_for_completion(&completion),
1283        vec![
1284            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1285            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1286            "delay",
1287            "echo",
1288            "infinite",
1289            "tool_requiring_permission",
1290            "unique_tool_1",
1291            "unique_tool_2",
1292            "word_list",
1293            "xxx_echo",
1294            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1295            "yyy_echo",
1296            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1297        ]
1298    );
1299}
1300
// e2e: cancelling mid-turn must close the event stream with a Cancelled stop
// reason even while a tool (the infinite tool) is still running, and the
// thread must accept new messages afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            // Tool calls must arrive in the prompted order: Echo first.
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo tool's Completed status update.
            // NOTE(review): the `meta: None` pattern only matches updates
            // without metadata — presumably completion updates carry no
            // meta here; confirm.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tools were observed and the echo tool finished;
        // the infinite tool is intentionally still running at this point.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1386
1387#[gpui::test]
1388async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1389    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1390    let fake_model = model.as_fake();
1391
1392    let events_1 = thread
1393        .update(cx, |thread, cx| {
1394            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1395        })
1396        .unwrap();
1397    cx.run_until_parked();
1398    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1399    cx.run_until_parked();
1400
1401    let events_2 = thread
1402        .update(cx, |thread, cx| {
1403            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1404        })
1405        .unwrap();
1406    cx.run_until_parked();
1407    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1408    fake_model
1409        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1410    fake_model.end_last_completion_stream();
1411
1412    let events_1 = events_1.collect::<Vec<_>>().await;
1413    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1414    let events_2 = events_2.collect::<Vec<_>>().await;
1415    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1416}
1417
1418#[gpui::test]
1419async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1420    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1421    let fake_model = model.as_fake();
1422
1423    let events_1 = thread
1424        .update(cx, |thread, cx| {
1425            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1426        })
1427        .unwrap();
1428    cx.run_until_parked();
1429    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1430    fake_model
1431        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1432    fake_model.end_last_completion_stream();
1433    let events_1 = events_1.collect::<Vec<_>>().await;
1434
1435    let events_2 = thread
1436        .update(cx, |thread, cx| {
1437            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1438        })
1439        .unwrap();
1440    cx.run_until_parked();
1441    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1442    fake_model
1443        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1444    fake_model.end_last_completion_stream();
1445    let events_2 = events_2.collect::<Vec<_>>().await;
1446
1447    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1448    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1449}
1450
1451#[gpui::test]
1452async fn test_refusal(cx: &mut TestAppContext) {
1453    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1454    let fake_model = model.as_fake();
1455
1456    let events = thread
1457        .update(cx, |thread, cx| {
1458            thread.send(UserMessageId::new(), ["Hello"], cx)
1459        })
1460        .unwrap();
1461    cx.run_until_parked();
1462    thread.read_with(cx, |thread, _| {
1463        assert_eq!(
1464            thread.to_markdown(),
1465            indoc! {"
1466                ## User
1467
1468                Hello
1469            "}
1470        );
1471    });
1472
1473    fake_model.send_last_completion_stream_text_chunk("Hey!");
1474    cx.run_until_parked();
1475    thread.read_with(cx, |thread, _| {
1476        assert_eq!(
1477            thread.to_markdown(),
1478            indoc! {"
1479                ## User
1480
1481                Hello
1482
1483                ## Assistant
1484
1485                Hey!
1486            "}
1487        );
1488    });
1489
1490    // If the model refuses to continue, the thread should remove all the messages after the last user message.
1491    fake_model
1492        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1493    let events = events.collect::<Vec<_>>().await;
1494    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1495    thread.read_with(cx, |thread, _| {
1496        assert_eq!(thread.to_markdown(), "");
1497    });
1498}
1499
// Truncating at the first user message must clear the whole thread (markdown
// and token usage), and the thread must remain usable afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No completion has reported usage yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage is the sum of input and output tokens from the UsageUpdate.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message empties the thread and resets usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1615
// Truncating at the second user message must restore the thread to the exact
// state (markdown and token usage) it had after the first exchange.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Captures the expected post-first-exchange state; asserted both before
    // the second message and again after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Keep the second message's id so we can truncate at it below.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and verify the first exchange (including
    // its token usage) is restored.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1724
1725#[gpui::test]
1726async fn test_title_generation(cx: &mut TestAppContext) {
1727    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1728    let fake_model = model.as_fake();
1729
1730    let summary_model = Arc::new(FakeLanguageModel::default());
1731    thread.update(cx, |thread, cx| {
1732        thread.set_summarization_model(Some(summary_model.clone()), cx)
1733    });
1734
1735    let send = thread
1736        .update(cx, |thread, cx| {
1737            thread.send(UserMessageId::new(), ["Hello"], cx)
1738        })
1739        .unwrap();
1740    cx.run_until_parked();
1741
1742    fake_model.send_last_completion_stream_text_chunk("Hey!");
1743    fake_model.end_last_completion_stream();
1744    cx.run_until_parked();
1745    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1746
1747    // Ensure the summary model has been invoked to generate a title.
1748    summary_model.send_last_completion_stream_text_chunk("Hello ");
1749    summary_model.send_last_completion_stream_text_chunk("world\nG");
1750    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1751    summary_model.end_last_completion_stream();
1752    send.collect::<Vec<_>>().await;
1753    cx.run_until_parked();
1754    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1755
1756    // Send another message, ensuring no title is generated this time.
1757    let send = thread
1758        .update(cx, |thread, cx| {
1759            thread.send(UserMessageId::new(), ["Hello again"], cx)
1760        })
1761        .unwrap();
1762    cx.run_until_parked();
1763    fake_model.send_last_completion_stream_text_chunk("Hey again!");
1764    fake_model.end_last_completion_stream();
1765    cx.run_until_parked();
1766    assert_eq!(summary_model.pending_completions(), Vec::new());
1767    send.collect::<Vec<_>>().await;
1768    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1769}
1770
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // Verifies that when one tool call is still pending (waiting on user
    // permission), building a completion request includes only the tool uses
    // that already have results, so the provider never sees a tool use
    // without a matching tool result.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call stays pending: its permission prompt is never answered.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool call runs immediately and produces a result.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped (system/context message); the remainder must
    // contain only the echo tool use and its result — the permission tool
    // is absent entirely.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1845
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of the ACP connection surface: creating a thread,
    // listing and selecting models, prompting, cancelling, and verifying that
    // a dropped thread's session is cleaned up.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models: the test registry exposes the fake provider's model
    // under the "Fake" group.
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model and check the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well: prompting the old session must now fail.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                    meta: None,
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1980
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    // Verifies the ACP event sequence for a streamed tool call: an initial
    // ToolCall (Pending) created from partial input, then field updates for
    // the completed input, an InProgress status, the streamed content, and a
    // final Completed status carrying the tool's raw output.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First event: the pending tool call built from the partial input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
            meta: Some(json!({ "tool_name": "thinking" })),
        }
    );
    // The completed input arrives as a field update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
            meta: None,
        }
    );
    // The tool starts running...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
            meta: None,
        }
    );
    // ...streams its content...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
            meta: None,
        }
    );
    // ...and finishes with its raw output attached.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
            meta: None,
        }
    );
}
2087
2088#[gpui::test]
2089async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2090    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2091    let fake_model = model.as_fake();
2092
2093    let mut events = thread
2094        .update(cx, |thread, cx| {
2095            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2096            thread.send(UserMessageId::new(), ["Hello!"], cx)
2097        })
2098        .unwrap();
2099    cx.run_until_parked();
2100
2101    fake_model.send_last_completion_stream_text_chunk("Hey!");
2102    fake_model.end_last_completion_stream();
2103
2104    let mut retry_events = Vec::new();
2105    while let Some(Ok(event)) = events.next().await {
2106        match event {
2107            ThreadEvent::Retry(retry_status) => {
2108                retry_events.push(retry_status);
2109            }
2110            ThreadEvent::Stop(..) => break,
2111            _ => {}
2112        }
2113    }
2114
2115    assert_eq!(retry_events.len(), 0);
2116    thread.read_with(cx, |thread, _cx| {
2117        assert_eq!(
2118            thread.to_markdown(),
2119            indoc! {"
2120                ## User
2121
2122                Hello!
2123
2124                ## Assistant
2125
2126                Hey!
2127            "}
2128        )
2129    });
2130}
2131
2132#[gpui::test]
2133async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2134    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2135    let fake_model = model.as_fake();
2136
2137    let mut events = thread
2138        .update(cx, |thread, cx| {
2139            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2140            thread.send(UserMessageId::new(), ["Hello!"], cx)
2141        })
2142        .unwrap();
2143    cx.run_until_parked();
2144
2145    fake_model.send_last_completion_stream_text_chunk("Hey,");
2146    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2147        provider: LanguageModelProviderName::new("Anthropic"),
2148        retry_after: Some(Duration::from_secs(3)),
2149    });
2150    fake_model.end_last_completion_stream();
2151
2152    cx.executor().advance_clock(Duration::from_secs(3));
2153    cx.run_until_parked();
2154
2155    fake_model.send_last_completion_stream_text_chunk("there!");
2156    fake_model.end_last_completion_stream();
2157    cx.run_until_parked();
2158
2159    let mut retry_events = Vec::new();
2160    while let Some(Ok(event)) = events.next().await {
2161        match event {
2162            ThreadEvent::Retry(retry_status) => {
2163                retry_events.push(retry_status);
2164            }
2165            ThreadEvent::Stop(..) => break,
2166            _ => {}
2167        }
2168    }
2169
2170    assert_eq!(retry_events.len(), 1);
2171    assert!(matches!(
2172        retry_events[0],
2173        acp_thread::RetryStatus { attempt: 1, .. }
2174    ));
2175    thread.read_with(cx, |thread, _cx| {
2176        assert_eq!(
2177            thread.to_markdown(),
2178            indoc! {"
2179                ## User
2180
2181                Hello!
2182
2183                ## Assistant
2184
2185                Hey,
2186
2187                [resume]
2188
2189                ## Assistant
2190
2191                there!
2192            "}
2193        )
2194    });
2195}
2196
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // When a retryable error arrives after a tool use was streamed, the tool
    // should still run to completion, and the retried request must contain
    // both the tool use and its result.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    // Fail the stream with an error that triggers an automatic retry.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the re-sent request should include the finished
    // tool call: the assistant's tool use followed by the user-role result.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true
            },
        ]
    );

    // Complete the retried turn and check the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default()
            }))
        );
    })
}
2272
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, the next
    // failure must surface as an error on the event stream instead of
    // scheduling yet another retry.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt, advancing past each retry delay.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One retry event per attempt, numbered starting from 1.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as a single ServerOverloaded error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2327
2328/// Filters out the stop events for asserting against in tests
2329fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2330    result_events
2331        .into_iter()
2332        .filter_map(|event| match event.unwrap() {
2333            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2334            _ => None,
2335        })
2336        .collect()
2337}
2338
/// Fixture returned by [`setup`]: the model under test, the thread itself,
/// and handles to collaborating state that individual tests may inspect.
struct ThreadTest {
    // The active language model (a FakeLanguageModel unless TestModel::Sonnet4).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Passed to `Thread::new`; tests can mutate it to alter the project context.
    project_context: Entity<ProjectContext>,
    // Used by tests that register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // The fake filesystem backing the project and the settings file.
    fs: Arc<FakeFs>,
}
2346
/// Which language model a test runs against: a real Claude Sonnet 4
/// (requires network and credentials) or the in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
2351
2352impl TestModel {
2353    fn id(&self) -> LanguageModelId {
2354        match self {
2355            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2356            TestModel::Fake => unreachable!(),
2357        }
2358    }
2359}
2360
/// Builds a [`ThreadTest`] fixture: a fake filesystem with a settings file
/// that enables all test tools under a dedicated profile, a settings watcher,
/// a test project, and a [`Thread`] backed by either the fake model or a
/// real, authenticated Sonnet 4 model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file enabling every tool these tests register.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no provider or authentication setup.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with edits to the fake
        // settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is constructed directly; a real model is
    // looked up in the registry and its provider authenticated first.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2462
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up logging when the caller opted in via RUST_LOG; keeps
    // test output quiet by default.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2470
2471fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2472    let fs = fs.clone();
2473    cx.spawn({
2474        async move |cx| {
2475            let mut new_settings_content_rx = settings::watch_config_file(
2476                cx.background_executor(),
2477                fs,
2478                paths::settings_file().clone(),
2479            );
2480
2481            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2482                cx.update(|cx| {
2483                    SettingsStore::update_global(cx, |settings, cx| {
2484                        settings.set_user_settings(&new_settings_content, cx)
2485                    })
2486                })
2487                .ok();
2488            }
2489        }
2490    })
2491    .detach();
2492}
2493
2494fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2495    completion
2496        .tools
2497        .iter()
2498        .map(|tool| tool.name.clone())
2499        .collect()
2500}
2501
/// Registers a fake MCP context server named `name` exposing `tools`, and
/// starts it in `context_server_store`.
///
/// Returns a receiver that yields each incoming tool call together with a
/// oneshot sender the test must use to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store knows about it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport that answers the MCP handshake (initialize/list-tools)
    // and forwards tool calls to the test through the channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish starting before tests interact with it.
    cx.run_until_parked();
    mcp_tool_calls_rx
}