mod.rs

use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use cloud_llm_client::CompletionIntent;
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use fs::{FakeFs, Fs};
use futures::{
    StreamExt,
    channel::{
        mpsc::{self, UnboundedReceiver},
        oneshot,
    },
};
use gpui::{
    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::{
    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
};
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::{Settings, SettingsStore};
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;

mod test_tools;
use test_tools::*;

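// A plain text response from the model should land in the thread as a single assistant message.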
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

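// Thinking chunks should be recorded ahead of the final answer, wrapped in <think> tags.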
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

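// The system prompt should pick up project context (the shell here) and include tool-related sections when tools are registered.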
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

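// With no tools registered, the system prompt should omit the tool use and diagnostics sections.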
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}

#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Trigger a tool call so we can observe its input streaming in incrementally.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

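// Tool calls that require permission should raise authorization requests and honor allow, reject, and "always allow" responses.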
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first
    tool_call_auth_1
        .response
        .send(tool_call_auth_1.options[1].id.clone())
        .unwrap();
    cx.run_until_parked();

    // Reject the second
    tool_call_auth_2
        .response
        .send(tool_call_auth_2.options[2].id.clone())
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools.
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(tool_call_auth_3.options[0].id.clone())
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}

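// A call to a tool that doesn't exist should show up as a tool call that immediately fails.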
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}

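// After the tool use limit is reached, resuming should replay the conversation with a "Continue where you left off" user message.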
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}

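// After the tool use limit is reached, a fresh user message should still be sendable and keep the prior tool use and result in the request.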
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}

async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
    let event = events
        .next()
        .await
        .expect("no tool call event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCall(tool_call) => tool_call,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

async fn expect_tool_call_update_fields(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
    let event = events
        .next()
        .await
        .expect("no tool call update event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

async fn next_tool_call_authorization(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
    loop {
        let event = events
            .next()
            .await
            .expect("no tool call authorization event received")
            .unwrap();
        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let permission_kinds = tool_call_authorization
                .options
                .iter()
                .map(|o| o.kind)
                .collect::<Vec<_>>();
            assert_eq!(
                permission_kinds,
                vec![
                    acp::PermissionOptionKind::AllowAlways,
                    acp::PermissionOptionKind::AllowOnce,
                    acp::PermissionOptionKind::RejectOnce,
                ]
            );
            return tool_call_authorization;
        }
    }
}

#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}

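// Switching agent profiles should change which tools are offered to the model.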
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Switch to the test-1 profile and verify that it exposes only the echo and delay tools.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to the test-2 profile and verify that it exposes only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}

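// Tools exposed by a context (MCP) server should be callable by the model, and name collisions with native tools should be resolved by prefixing the server name.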
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}

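// MCP tool names that collide with native tools or exceed MAX_TOOL_NAME_LENGTH should be disambiguated and truncated to fit.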
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}

#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    meta: None,
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

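// Starting a new send while another is still streaming should cancel the first turn.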
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

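// Sends that finish before the next one starts should both end with EndTurn rather than being cancelled.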
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

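// A Refusal stop reason should surface as a Refusal stop event and roll back the refused exchange.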
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}

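// Truncating at the first message should clear the thread and its token usage while still allowing new messages afterwards.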
1507#[gpui::test]
1508async fn test_truncate_first_message(cx: &mut TestAppContext) {
1509    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1510    let fake_model = model.as_fake();
1511
1512    let message_id = UserMessageId::new();
1513    thread
1514        .update(cx, |thread, cx| {
1515            thread.send(message_id.clone(), ["Hello"], cx)
1516        })
1517        .unwrap();
1518    cx.run_until_parked();
1519    thread.read_with(cx, |thread, _| {
1520        assert_eq!(
1521            thread.to_markdown(),
1522            indoc! {"
1523                ## User
1524
1525                Hello
1526            "}
1527        );
1528        assert_eq!(thread.latest_token_usage(), None);
1529    });
1530
1531    fake_model.send_last_completion_stream_text_chunk("Hey!");
1532    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1533        language_model::TokenUsage {
1534            input_tokens: 32_000,
1535            output_tokens: 16_000,
1536            cache_creation_input_tokens: 0,
1537            cache_read_input_tokens: 0,
1538        },
1539    ));
1540    cx.run_until_parked();
1541    thread.read_with(cx, |thread, _| {
1542        assert_eq!(
1543            thread.to_markdown(),
1544            indoc! {"
1545                ## User
1546
1547                Hello
1548
1549                ## Assistant
1550
1551                Hey!
1552            "}
1553        );
1554        assert_eq!(
1555            thread.latest_token_usage(),
1556            Some(acp_thread::TokenUsage {
1557                used_tokens: 32_000 + 16_000,
1558                max_tokens: 1_000_000,
1559            })
1560        );
1561    });
1562
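        // Truncating at the first (and only) user message should clear the thread and reset token usage.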
1563    thread
1564        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1565        .unwrap();
1566    cx.run_until_parked();
1567    thread.read_with(cx, |thread, _| {
1568        assert_eq!(thread.to_markdown(), "");
1569        assert_eq!(thread.latest_token_usage(), None);
1570    });
1571
1572    // Ensure we can still send a new message after truncation.
1573    thread
1574        .update(cx, |thread, cx| {
1575            thread.send(UserMessageId::new(), ["Hi"], cx)
1576        })
1577        .unwrap();
1578    thread.update(cx, |thread, _cx| {
1579        assert_eq!(
1580            thread.to_markdown(),
1581            indoc! {"
1582                ## User
1583
1584                Hi
1585            "}
1586        );
1587    });
1588    cx.run_until_parked();
1589    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1590    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1591        language_model::TokenUsage {
1592            input_tokens: 40_000,
1593            output_tokens: 20_000,
1594            cache_creation_input_tokens: 0,
1595            cache_read_input_tokens: 0,
1596        },
1597    ));
1598    cx.run_until_parked();
1599    thread.read_with(cx, |thread, _| {
1600        assert_eq!(
1601            thread.to_markdown(),
1602            indoc! {"
1603                ## User
1604
1605                Hi
1606
1607                ## Assistant
1608
1609                Ahoy!
1610            "}
1611        );
1612
1613        assert_eq!(
1614            thread.latest_token_usage(),
1615            Some(acp_thread::TokenUsage {
1616                used_tokens: 40_000 + 20_000,
1617                max_tokens: 1_000_000,
1618            })
1619        );
1620    });
1621}
1622
1623#[gpui::test]
1624async fn test_truncate_second_message(cx: &mut TestAppContext) {
1625    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1626    let fake_model = model.as_fake();
1627
1628    thread
1629        .update(cx, |thread, cx| {
1630            thread.send(UserMessageId::new(), ["Message 1"], cx)
1631        })
1632        .unwrap();
1633    cx.run_until_parked();
1634    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
1635    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1636        language_model::TokenUsage {
1637            input_tokens: 32_000,
1638            output_tokens: 16_000,
1639            cache_creation_input_tokens: 0,
1640            cache_read_input_tokens: 0,
1641        },
1642    ));
1643    fake_model.end_last_completion_stream();
1644    cx.run_until_parked();
1645
1646    let assert_first_message_state = |cx: &mut TestAppContext| {
1647        thread.clone().read_with(cx, |thread, _| {
1648            assert_eq!(
1649                thread.to_markdown(),
1650                indoc! {"
1651                    ## User
1652
1653                    Message 1
1654
1655                    ## Assistant
1656
1657                    Message 1 response
1658                "}
1659            );
1660
1661            assert_eq!(
1662                thread.latest_token_usage(),
1663                Some(acp_thread::TokenUsage {
1664                    used_tokens: 32_000 + 16_000,
1665                    max_tokens: 1_000_000,
1666                })
1667            );
1668        });
1669    };
1670
1671    assert_first_message_state(cx);
1672
1673    let second_message_id = UserMessageId::new();
1674    thread
1675        .update(cx, |thread, cx| {
1676            thread.send(second_message_id.clone(), ["Message 2"], cx)
1677        })
1678        .unwrap();
1679    cx.run_until_parked();
1680
1681    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
1682    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1683        language_model::TokenUsage {
1684            input_tokens: 40_000,
1685            output_tokens: 20_000,
1686            cache_creation_input_tokens: 0,
1687            cache_read_input_tokens: 0,
1688        },
1689    ));
1690    fake_model.end_last_completion_stream();
1691    cx.run_until_parked();
1692
1693    thread.read_with(cx, |thread, _| {
1694        assert_eq!(
1695            thread.to_markdown(),
1696            indoc! {"
1697                ## User
1698
1699                Message 1
1700
1701                ## Assistant
1702
1703                Message 1 response
1704
1705                ## User
1706
1707                Message 2
1708
1709                ## Assistant
1710
1711                Message 2 response
1712            "}
1713        );
1714
1715        assert_eq!(
1716            thread.latest_token_usage(),
1717            Some(acp_thread::TokenUsage {
1718                used_tokens: 40_000 + 20_000,
1719                max_tokens: 1_000_000,
1720            })
1721        );
1722    });
1723
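        // Truncating at the second user message should restore the thread to its first-message state.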
1724    thread
1725        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
1726        .unwrap();
1727    cx.run_until_parked();
1728
1729    assert_first_message_state(cx);
1730}
1731
1732#[gpui::test]
1733async fn test_title_generation(cx: &mut TestAppContext) {
1734    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1735    let fake_model = model.as_fake();
1736
1737    let summary_model = Arc::new(FakeLanguageModel::default());
1738    thread.update(cx, |thread, cx| {
1739        thread.set_summarization_model(Some(summary_model.clone()), cx)
1740    });
1741
1742    let send = thread
1743        .update(cx, |thread, cx| {
1744            thread.send(UserMessageId::new(), ["Hello"], cx)
1745        })
1746        .unwrap();
1747    cx.run_until_parked();
1748
1749    fake_model.send_last_completion_stream_text_chunk("Hey!");
1750    fake_model.end_last_completion_stream();
1751    cx.run_until_parked();
1752    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1753
1754    // The summary model should have been invoked; simulate it streaming a multi-line title, of which only the first line is used.
1755    summary_model.send_last_completion_stream_text_chunk("Hello ");
1756    summary_model.send_last_completion_stream_text_chunk("world\nG");
1757    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1758    summary_model.end_last_completion_stream();
1759    send.collect::<Vec<_>>().await;
1760    cx.run_until_parked();
1761    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1762
1763    // Send another message, ensuring no title is generated this time.
1764    let send = thread
1765        .update(cx, |thread, cx| {
1766            thread.send(UserMessageId::new(), ["Hello again"], cx)
1767        })
1768        .unwrap();
1769    cx.run_until_parked();
1770    fake_model.send_last_completion_stream_text_chunk("Hey again!");
1771    fake_model.end_last_completion_stream();
1772    cx.run_until_parked();
1773    assert_eq!(summary_model.pending_completions(), Vec::new());
1774    send.collect::<Vec<_>>().await;
1775    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1776}
1777
1778#[gpui::test]
1779async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
1780    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1781    let fake_model = model.as_fake();
1782
1783    let _events = thread
1784        .update(cx, |thread, cx| {
1785            thread.add_tool(ToolRequiringPermission);
1786            thread.add_tool(EchoTool);
1787            thread.send(UserMessageId::new(), ["Hey!"], cx)
1788        })
1789        .unwrap();
1790    cx.run_until_parked();
1791
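        // The model emits two tool calls: one that requires permission (and stays pending) and one that runs to completion.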
1792    let permission_tool_use = LanguageModelToolUse {
1793        id: "tool_id_1".into(),
1794        name: ToolRequiringPermission::name().into(),
1795        raw_input: "{}".into(),
1796        input: json!({}),
1797        is_input_complete: true,
1798        thought_signature: None,
1799    };
1800    let echo_tool_use = LanguageModelToolUse {
1801        id: "tool_id_2".into(),
1802        name: EchoTool::name().into(),
1803        raw_input: json!({"text": "test"}).to_string(),
1804        input: json!({"text": "test"}),
1805        is_input_complete: true,
1806        thought_signature: None,
1807    };
1808    fake_model.send_last_completion_stream_text_chunk("Hi!");
1809    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1810        permission_tool_use,
1811    ));
1812    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1813        echo_tool_use.clone(),
1814    ));
1815    fake_model.end_last_completion_stream();
1816    cx.run_until_parked();
1817
1818    // Ensure pending tools are skipped when building a request.
1819    let request = thread
1820        .read_with(cx, |thread, cx| {
1821            thread.build_completion_request(CompletionIntent::EditFile, cx)
1822        })
1823        .unwrap();
1824    assert_eq!(
1825        request.messages[1..],
1826        vec![
1827            LanguageModelRequestMessage {
1828                role: Role::User,
1829                content: vec!["Hey!".into()],
1830                cache: true
1831            },
1832            LanguageModelRequestMessage {
1833                role: Role::Assistant,
1834                content: vec![
1835                    MessageContent::Text("Hi!".into()),
1836                    MessageContent::ToolUse(echo_tool_use.clone())
1837                ],
1838                cache: false
1839            },
1840            LanguageModelRequestMessage {
1841                role: Role::User,
1842                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
1843                    tool_use_id: echo_tool_use.id.clone(),
1844                    tool_name: echo_tool_use.name,
1845                    is_error: false,
1846                    content: "test".into(),
1847                    output: Some("test".into())
1848                })],
1849                cache: false
1850            },
1851        ],
1852    );
1853}
1854
1855#[gpui::test]
1856async fn test_agent_connection(cx: &mut TestAppContext) {
1857    cx.update(settings::init);
1858    let templates = Templates::new();
1859
1860    // Initialize language model system with test provider
1861    cx.update(|cx| {
1862        gpui_tokio::init(cx);
1863
1864        let http_client = FakeHttpClient::with_404_response();
1865        let clock = Arc::new(clock::FakeSystemClock::new());
1866        let client = Client::new(clock, http_client, cx);
1867        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1868        language_model::init(client.clone(), cx);
1869        language_models::init(user_store, client.clone(), cx);
1870        LanguageModelRegistry::test(cx);
1871    });
1872    cx.executor().forbid_parking();
1873
1874    // Create a project for new_thread
1875    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
1876    fake_fs.insert_tree(path!("/test"), json!({})).await;
1877    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
1878    let cwd = Path::new("/test");
1879    let text_thread_store =
1880        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
1881    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
1882
1883    // Create agent and connection
1884    let agent = NativeAgent::new(
1885        project.clone(),
1886        history_store,
1887        templates.clone(),
1888        None,
1889        fake_fs.clone(),
1890        &mut cx.to_async(),
1891    )
1892    .await
1893    .unwrap();
1894    let connection = NativeAgentConnection(agent.clone());
1895
1896    // Create a thread using new_thread
1897    let connection_rc = Rc::new(connection.clone());
1898    let acp_thread = cx
1899        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1900        .await
1901        .expect("new_thread should succeed");
1902
1903    // Get the session_id from the AcpThread
1904    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1905
1906    // Test model_selector returns Some
1907    let selector_opt = connection.model_selector(&session_id);
1908    assert!(
1909        selector_opt.is_some(),
1910        "agent should always support ModelSelector"
1911    );
1912    let selector = selector_opt.unwrap();
1913
1914    // Test list_models
1915    let listed_models = cx
1916        .update(|cx| selector.list_models(cx))
1917        .await
1918        .expect("list_models should succeed");
1919    let AgentModelList::Grouped(listed_models) = listed_models else {
1920        panic!("Unexpected model list type");
1921    };
1922    assert!(!listed_models.is_empty(), "should have at least one model");
1923    assert_eq!(
1924        listed_models[&AgentModelGroupName("Fake".into())][0]
1925            .id
1926            .0
1927            .as_ref(),
1928        "fake/fake"
1929    );
1930
1931    // Test selected_model returns the default
1932    let model = cx
1933        .update(|cx| selector.selected_model(cx))
1934        .await
1935        .expect("selected_model should succeed");
1936    let model = cx
1937        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1938        .unwrap();
1939    let model = model.as_fake();
1940    assert_eq!(model.id().0, "fake", "should return default model");
1941
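        // Send a prompt through the ACP thread and stream a response from the fake model.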
1942    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1943    cx.run_until_parked();
1944    model.send_last_completion_stream_text_chunk("def");
1945    cx.run_until_parked();
1946    acp_thread.read_with(cx, |thread, cx| {
1947        assert_eq!(
1948            thread.to_markdown(cx),
1949            indoc! {"
1950                ## User
1951
1952                abc
1953
1954                ## Assistant
1955
1956                def
1957
1958            "}
1959        )
1960    });
1961
1962    // Test cancel
1963    cx.update(|cx| connection.cancel(&session_id, cx));
1964    request.await.expect("prompt should resolve gracefully after cancellation");
1965
1966    // Ensure that dropping the ACP thread causes the native thread to be
1967    // dropped as well.
1968    cx.update(|_| drop(acp_thread));
1969    let result = cx
1970        .update(|cx| {
1971            connection.prompt(
1972                Some(acp_thread::UserMessageId::new()),
1973                acp::PromptRequest {
1974                    session_id: session_id.clone(),
1975                    prompt: vec!["ghi".into()],
1976                    meta: None,
1977                },
1978                cx,
1979            )
1980        })
1981        .await;
1982    assert_eq!(
1983        result.as_ref().unwrap_err().to_string(),
1984        "Session not found",
1985        "unexpected result: {:?}",
1986        result
1987    );
1988}
1989
1990#[gpui::test]
1991async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1992    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1993    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1994    let fake_model = model.as_fake();
1995
1996    let mut events = thread
1997        .update(cx, |thread, cx| {
1998            thread.send(UserMessageId::new(), ["Think"], cx)
1999        })
2000        .unwrap();
2001    cx.run_until_parked();
2002
2003    // Simulate streaming partial input.
2004    let input = json!({});
2005    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2006        LanguageModelToolUse {
2007            id: "1".into(),
2008            name: ThinkingTool::name().into(),
2009            raw_input: input.to_string(),
2010            input,
2011            is_input_complete: false,
2012            thought_signature: None,
2013        },
2014    ));
2015
2016    // Input streaming completed
2017    let input = json!({ "content": "Thinking hard!" });
2018    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2019        LanguageModelToolUse {
2020            id: "1".into(),
2021            name: "thinking".into(),
2022            raw_input: input.to_string(),
2023            input,
2024            is_input_complete: true,
2025            thought_signature: None,
2026        },
2027    ));
2028    fake_model.end_last_completion_stream();
2029    cx.run_until_parked();
2030
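        // Expect a pending tool call, followed by updates for the streamed input, status, content, and completion.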
2031    let tool_call = expect_tool_call(&mut events).await;
2032    assert_eq!(
2033        tool_call,
2034        acp::ToolCall {
2035            id: acp::ToolCallId("1".into()),
2036            title: "Thinking".into(),
2037            kind: acp::ToolKind::Think,
2038            status: acp::ToolCallStatus::Pending,
2039            content: vec![],
2040            locations: vec![],
2041            raw_input: Some(json!({})),
2042            raw_output: None,
2043            meta: Some(json!({ "tool_name": "thinking" })),
2044        }
2045    );
2046    let update = expect_tool_call_update_fields(&mut events).await;
2047    assert_eq!(
2048        update,
2049        acp::ToolCallUpdate {
2050            id: acp::ToolCallId("1".into()),
2051            fields: acp::ToolCallUpdateFields {
2052                title: Some("Thinking".into()),
2053                kind: Some(acp::ToolKind::Think),
2054                raw_input: Some(json!({ "content": "Thinking hard!" })),
2055                ..Default::default()
2056            },
2057            meta: None,
2058        }
2059    );
2060    let update = expect_tool_call_update_fields(&mut events).await;
2061    assert_eq!(
2062        update,
2063        acp::ToolCallUpdate {
2064            id: acp::ToolCallId("1".into()),
2065            fields: acp::ToolCallUpdateFields {
2066                status: Some(acp::ToolCallStatus::InProgress),
2067                ..Default::default()
2068            },
2069            meta: None,
2070        }
2071    );
2072    let update = expect_tool_call_update_fields(&mut events).await;
2073    assert_eq!(
2074        update,
2075        acp::ToolCallUpdate {
2076            id: acp::ToolCallId("1".into()),
2077            fields: acp::ToolCallUpdateFields {
2078                content: Some(vec!["Thinking hard!".into()]),
2079                ..Default::default()
2080            },
2081            meta: None,
2082        }
2083    );
2084    let update = expect_tool_call_update_fields(&mut events).await;
2085    assert_eq!(
2086        update,
2087        acp::ToolCallUpdate {
2088            id: acp::ToolCallId("1".into()),
2089            fields: acp::ToolCallUpdateFields {
2090                status: Some(acp::ToolCallStatus::Completed),
2091                raw_output: Some("Finished thinking.".into()),
2092                ..Default::default()
2093            },
2094            meta: None,
2095        }
2096    );
2097}
2098
2099#[gpui::test]
2100async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2101    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2102    let fake_model = model.as_fake();
2103
2104    let mut events = thread
2105        .update(cx, |thread, cx| {
2106            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2107            thread.send(UserMessageId::new(), ["Hello!"], cx)
2108        })
2109        .unwrap();
2110    cx.run_until_parked();
2111
2112    fake_model.send_last_completion_stream_text_chunk("Hey!");
2113    fake_model.end_last_completion_stream();
2114
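        // Drain the event stream; a successful completion should emit no retry events.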
2115    let mut retry_events = Vec::new();
2116    while let Some(Ok(event)) = events.next().await {
2117        match event {
2118            ThreadEvent::Retry(retry_status) => {
2119                retry_events.push(retry_status);
2120            }
2121            ThreadEvent::Stop(..) => break,
2122            _ => {}
2123        }
2124    }
2125
2126    assert_eq!(retry_events.len(), 0);
2127    thread.read_with(cx, |thread, _cx| {
2128        assert_eq!(
2129            thread.to_markdown(),
2130            indoc! {"
2131                ## User
2132
2133                Hello!
2134
2135                ## Assistant
2136
2137                Hey!
2138            "}
2139        )
2140    });
2141}
2142
2143#[gpui::test]
2144async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2145    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2146    let fake_model = model.as_fake();
2147
2148    let mut events = thread
2149        .update(cx, |thread, cx| {
2150            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2151            thread.send(UserMessageId::new(), ["Hello!"], cx)
2152        })
2153        .unwrap();
2154    cx.run_until_parked();
2155
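        // Stream part of the response, then fail with a retryable server-overloaded error.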
2156    fake_model.send_last_completion_stream_text_chunk("Hey,");
2157    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2158        provider: LanguageModelProviderName::new("Anthropic"),
2159        retry_after: Some(Duration::from_secs(3)),
2160    });
2161    fake_model.end_last_completion_stream();
2162
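        // Once the retry delay elapses, the thread retries and the model finishes the response.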
2163    cx.executor().advance_clock(Duration::from_secs(3));
2164    cx.run_until_parked();
2165
2166    fake_model.send_last_completion_stream_text_chunk("there!");
2167    fake_model.end_last_completion_stream();
2168    cx.run_until_parked();
2169
2170    let mut retry_events = Vec::new();
2171    while let Some(Ok(event)) = events.next().await {
2172        match event {
2173            ThreadEvent::Retry(retry_status) => {
2174                retry_events.push(retry_status);
2175            }
2176            ThreadEvent::Stop(..) => break,
2177            _ => {}
2178        }
2179    }
2180
2181    assert_eq!(retry_events.len(), 1);
2182    assert!(matches!(
2183        retry_events[0],
2184        acp_thread::RetryStatus { attempt: 1, .. }
2185    ));
2186    thread.read_with(cx, |thread, _cx| {
2187        assert_eq!(
2188            thread.to_markdown(),
2189            indoc! {"
2190                ## User
2191
2192                Hello!
2193
2194                ## Assistant
2195
2196                Hey,
2197
2198                [resume]
2199
2200                ## Assistant
2201
2202                there!
2203            "}
2204        )
2205    });
2206}
2207
2208#[gpui::test]
2209async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
2210    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2211    let fake_model = model.as_fake();
2212
2213    let events = thread
2214        .update(cx, |thread, cx| {
2215            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2216            thread.add_tool(EchoTool);
2217            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
2218        })
2219        .unwrap();
2220    cx.run_until_parked();
2221
2222    let tool_use_1 = LanguageModelToolUse {
2223        id: "tool_1".into(),
2224        name: EchoTool::name().into(),
2225        raw_input: json!({"text": "test"}).to_string(),
2226        input: json!({"text": "test"}),
2227        is_input_complete: true,
2228        thought_signature: None,
2229    };
2230    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2231        tool_use_1.clone(),
2232    ));
2233    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2234        provider: LanguageModelProviderName::new("Anthropic"),
2235        retry_after: Some(Duration::from_secs(3)),
2236    });
2237    fake_model.end_last_completion_stream();
2238
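        // On retry, the tool call from the failed turn should have completed and its result should be included in the rebuilt request.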
2239    cx.executor().advance_clock(Duration::from_secs(3));
2240    let completion = fake_model.pending_completions().pop().unwrap();
2241    assert_eq!(
2242        completion.messages[1..],
2243        vec![
2244            LanguageModelRequestMessage {
2245                role: Role::User,
2246                content: vec!["Call the echo tool!".into()],
2247                cache: false
2248            },
2249            LanguageModelRequestMessage {
2250                role: Role::Assistant,
2251                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
2252                cache: false
2253            },
2254            LanguageModelRequestMessage {
2255                role: Role::User,
2256                content: vec![language_model::MessageContent::ToolResult(
2257                    LanguageModelToolResult {
2258                        tool_use_id: tool_use_1.id.clone(),
2259                        tool_name: tool_use_1.name.clone(),
2260                        is_error: false,
2261                        content: "test".into(),
2262                        output: Some("test".into())
2263                    }
2264                )],
2265                cache: true
2266            },
2267        ]
2268    );
2269
2270    fake_model.send_last_completion_stream_text_chunk("Done");
2271    fake_model.end_last_completion_stream();
2272    cx.run_until_parked();
2273    events.collect::<Vec<_>>().await;
2274    thread.read_with(cx, |thread, _cx| {
2275        assert_eq!(
2276            thread.last_message(),
2277            Some(Message::Agent(AgentMessage {
2278                content: vec![AgentMessageContent::Text("Done".into())],
2279                tool_results: IndexMap::default()
2280            }))
2281        );
2282    })
2283}
2284
2285#[gpui::test]
2286async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2287    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2288    let fake_model = model.as_fake();
2289
2290    let mut events = thread
2291        .update(cx, |thread, cx| {
2292            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2293            thread.send(UserMessageId::new(), ["Hello!"], cx)
2294        })
2295        .unwrap();
2296    cx.run_until_parked();
2297
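        // Fail one more time than the retry budget allows.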
2298    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2299        fake_model.send_last_completion_stream_error(
2300            LanguageModelCompletionError::ServerOverloaded {
2301                provider: LanguageModelProviderName::new("Anthropic"),
2302                retry_after: Some(Duration::from_secs(3)),
2303            },
2304        );
2305        fake_model.end_last_completion_stream();
2306        cx.executor().advance_clock(Duration::from_secs(3));
2307        cx.run_until_parked();
2308    }
2309
2310    let mut errors = Vec::new();
2311    let mut retry_events = Vec::new();
2312    while let Some(event) = events.next().await {
2313        match event {
2314            Ok(ThreadEvent::Retry(retry_status)) => {
2315                retry_events.push(retry_status);
2316            }
2317            Ok(ThreadEvent::Stop(..)) => break,
2318            Err(error) => errors.push(error),
2319            _ => {}
2320        }
2321    }
2322
2323    assert_eq!(
2324        retry_events.len(),
2325        crate::thread::MAX_RETRY_ATTEMPTS as usize
2326    );
2327    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2328        assert_eq!(retry_events[i].attempt, i + 1);
2329    }
2330    assert_eq!(errors.len(), 1);
2331    let error = errors[0]
2332        .downcast_ref::<LanguageModelCompletionError>()
2333        .unwrap();
2334    assert!(matches!(
2335        error,
2336        LanguageModelCompletionError::ServerOverloaded { .. }
2337    ));
2338}
2339
2340/// Extracts the stop events from a stream's results, for asserting against in tests
2341fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2342    result_events
2343        .into_iter()
2344        .filter_map(|event| match event.unwrap() {
2345            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2346            _ => None,
2347        })
2348        .collect()
2349}
2350
2351struct ThreadTest {
2352    model: Arc<dyn LanguageModel>,
2353    thread: Entity<Thread>,
2354    project_context: Entity<ProjectContext>,
2355    context_server_store: Entity<ContextServerStore>,
2356    fs: Arc<FakeFs>,
2357}
2358
2359enum TestModel {
2360    Sonnet4,
2361    Fake,
2362}
2363
2364impl TestModel {
2365    fn id(&self) -> LanguageModelId {
2366        match self {
2367            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2368            TestModel::Fake => unreachable!(),
2369        }
2370    }
2371}
2372
2373async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
2374    cx.executor().allow_parking();
2375
2376    let fs = FakeFs::new(cx.background_executor.clone());
2377    fs.create_dir(paths::settings_file().parent().unwrap())
2378        .await
2379        .unwrap();
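        // Seed a settings file that enables every test tool under a dedicated profile.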
2380    fs.insert_file(
2381        paths::settings_file(),
2382        json!({
2383            "agent": {
2384                "default_profile": "test-profile",
2385                "profiles": {
2386                    "test-profile": {
2387                        "name": "Test Profile",
2388                        "tools": {
2389                            EchoTool::name(): true,
2390                            DelayTool::name(): true,
2391                            WordListTool::name(): true,
2392                            ToolRequiringPermission::name(): true,
2393                            InfiniteTool::name(): true,
2394                            ThinkingTool::name(): true,
2395                        }
2396                    }
2397                }
2398            }
2399        })
2400        .to_string()
2401        .into_bytes(),
2402    )
2403    .await;
2404
2405    cx.update(|cx| {
2406        settings::init(cx);
2407
2408        match model {
2409            TestModel::Fake => {}
2410            TestModel::Sonnet4 => {
2411                gpui_tokio::init(cx);
2412                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
2413                cx.set_http_client(Arc::new(http_client));
2414                let client = Client::production(cx);
2415                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2416                language_model::init(client.clone(), cx);
2417                language_models::init(user_store, client.clone(), cx);
2418            }
2419        };
2420
2421        watch_settings(fs.clone(), cx);
2422    });
2423
2424    let templates = Templates::new();
2425
2426    fs.insert_tree(path!("/test"), json!({})).await;
2427    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2428
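        // Resolve the model: the fake model for unit tests, or an authenticated provider model for live runs.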
2429    let model = cx
2430        .update(|cx| {
2431            if let TestModel::Fake = model {
2432                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
2433            } else {
2434                let model_id = model.id();
2435                let models = LanguageModelRegistry::read_global(cx);
2436                let model = models
2437                    .available_models(cx)
2438                    .find(|model| model.id() == model_id)
2439                    .unwrap();
2440
2441                let provider = models.provider(&model.provider_id()).unwrap();
2442                let authenticated = provider.authenticate(cx);
2443
2444                cx.spawn(async move |_cx| {
2445                    authenticated.await.unwrap();
2446                    model
2447                })
2448            }
2449        })
2450        .await;
2451
2452    let project_context = cx.new(|_cx| ProjectContext::default());
2453    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
2454    let context_server_registry =
2455        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
2456    let thread = cx.new(|cx| {
2457        Thread::new(
2458            project,
2459            project_context.clone(),
2460            context_server_registry,
2461            templates,
2462            Some(model.clone()),
2463            cx,
2464        )
2465    });
2466    ThreadTest {
2467        model,
2468        thread,
2469        project_context,
2470        context_server_store,
2471        fs,
2472    }
2473}
2474
2475#[cfg(test)]
2476#[ctor::ctor]
2477fn init_logger() {
2478    if std::env::var("RUST_LOG").is_ok() {
2479        env_logger::init();
2480    }
2481}
2482
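    /// Keeps the global `SettingsStore` in sync with changes to the settings file on the fake filesystem.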
2483fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2484    let fs = fs.clone();
2485    cx.spawn({
2486        async move |cx| {
2487            let mut new_settings_content_rx = settings::watch_config_file(
2488                cx.background_executor(),
2489                fs,
2490                paths::settings_file().clone(),
2491            );
2492
2493            while let Some(new_settings_content) = new_settings_content_rx.next().await {
2494                cx.update(|cx| {
2495                    SettingsStore::update_global(cx, |settings, cx| {
2496                        settings.set_user_settings(&new_settings_content, cx)
2497                    })
2498                })
2499                .ok();
2500            }
2501        }
2502    })
2503    .detach();
2504}
2505
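    /// Returns the names of the tools offered in a completion request.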
2506fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2507    completion
2508        .tools
2509        .iter()
2510        .map(|tool| tool.name.clone())
2511        .collect()
2512}
2513
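    /// Starts a fake MCP context server exposing the given tools and returns a receiver of incoming
    /// tool calls, each paired with a channel for sending the response.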
2514fn setup_context_server(
2515    name: &'static str,
2516    tools: Vec<context_server::types::Tool>,
2517    context_server_store: &Entity<ContextServerStore>,
2518    cx: &mut TestAppContext,
2519) -> mpsc::UnboundedReceiver<(
2520    context_server::types::CallToolParams,
2521    oneshot::Sender<context_server::types::CallToolResponse>,
2522)> {
2523    cx.update(|cx| {
2524        let mut settings = ProjectSettings::get_global(cx).clone();
2525        settings.context_servers.insert(
2526            name.into(),
2527            project::project_settings::ContextServerSettings::Custom {
2528                enabled: true,
2529                command: ContextServerCommand {
2530                    path: "somebinary".into(),
2531                    args: Vec::new(),
2532                    env: None,
2533                    timeout: None,
2534                },
2535            },
2536        );
2537        ProjectSettings::override_global(settings, cx);
2538    });
2539
2540    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
2541    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
2542        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
2543            context_server::types::InitializeResponse {
2544                protocol_version: context_server::types::ProtocolVersion(
2545                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
2546                ),
2547                server_info: context_server::types::Implementation {
2548                    name: name.into(),
2549                    version: "1.0.0".to_string(),
2550                },
2551                capabilities: context_server::types::ServerCapabilities {
2552                    tools: Some(context_server::types::ToolsCapabilities {
2553                        list_changed: Some(true),
2554                    }),
2555                    ..Default::default()
2556                },
2557                meta: None,
2558            }
2559        })
2560        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
2561            let tools = tools.clone();
2562            async move {
2563                context_server::types::ListToolsResponse {
2564                    tools,
2565                    next_cursor: None,
2566                    meta: None,
2567                }
2568            }
2569        })
2570        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
2571            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
2572            async move {
2573                let (response_tx, response_rx) = oneshot::channel();
2574                mcp_tool_calls_tx
2575                    .unbounded_send((params, response_tx))
2576                    .unwrap();
2577                response_rx.await.unwrap()
2578            }
2579        });
2580    context_server_store.update(cx, |store, cx| {
2581        store.start_server(
2582            Arc::new(ContextServer::new(
2583                ContextServerId(name.into()),
2584                Arc::new(fake_transport),
2585            )),
2586            cx,
2587        );
2588    });
2589    cx.run_until_parked();
2590    mcp_tool_calls_rx
2591}