mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol::{self as acp};
   5use agent_settings::AgentProfileId;
   6use anyhow::Result;
   7use client::{Client, UserStore};
   8use fs::{FakeFs, Fs};
   9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
  10use gpui::{
  11    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  12};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  16    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
  17    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
  18    fake_provider::FakeLanguageModel,
  19};
  20use pretty_assertions::assert_eq;
  21use project::Project;
  22use prompt_store::ProjectContext;
  23use reqwest_client::ReqwestClient;
  24use schemars::JsonSchema;
  25use serde::{Deserialize, Serialize};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  29use util::path;
  30
  31mod test_tools;
  32use test_tools::*;
  33
  34#[gpui::test]
  35#[ignore = "can't run on CI yet"]
  36async fn test_echo(cx: &mut TestAppContext) {
  37    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  38
  39    let events = thread
  40        .update(cx, |thread, cx| {
  41            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  42        })
  43        .unwrap()
  44        .collect()
  45        .await;
  46    thread.update(cx, |thread, _cx| {
  47        assert_eq!(
  48            thread.last_message().unwrap().to_markdown(),
  49            indoc! {"
  50                ## Assistant
  51
  52                Hello
  53            "}
  54        )
  55    });
  56    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  57}
  58
  59#[gpui::test]
  60#[ignore = "can't run on CI yet"]
  61async fn test_thinking(cx: &mut TestAppContext) {
  62    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  63
  64    let events = thread
  65        .update(cx, |thread, cx| {
  66            thread.send(
  67                UserMessageId::new(),
  68                [indoc! {"
  69                    Testing:
  70
  71                    Generate a thinking step where you just think the word 'Think',
  72                    and have your final answer be 'Hello'
  73                "}],
  74                cx,
  75            )
  76        })
  77        .unwrap()
  78        .collect()
  79        .await;
  80    thread.update(cx, |thread, _cx| {
  81        assert_eq!(
  82            thread.last_message().unwrap().to_markdown(),
  83            indoc! {"
  84                ## Assistant
  85
  86                <think>Think</think>
  87                Hello
  88            "}
  89        )
  90    });
  91    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  92}
  93
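     // The system prompt goes out as the first request message and should reflect project context (the shell) along with standard sections like "## Fixing Diagnostics".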
  94#[gpui::test]
  95async fn test_system_prompt(cx: &mut TestAppContext) {
  96    let ThreadTest {
  97        model,
  98        thread,
  99        project_context,
 100        ..
 101    } = setup(cx, TestModel::Fake).await;
 102    let fake_model = model.as_fake();
 103
 104    project_context.update(cx, |project_context, _cx| {
 105        project_context.shell = "test-shell".into()
 106    });
 107    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 108    thread
 109        .update(cx, |thread, cx| {
 110            thread.send(UserMessageId::new(), ["abc"], cx)
 111        })
 112        .unwrap();
 113    cx.run_until_parked();
 114    let mut pending_completions = fake_model.pending_completions();
 115    assert_eq!(
 116        pending_completions.len(),
 117        1,
 118        "unexpected pending completions: {:?}",
 119        pending_completions
 120    );
 121
 122    let pending_completion = pending_completions.pop().unwrap();
 123    assert_eq!(pending_completion.messages[0].role, Role::System);
 124
 125    let system_message = &pending_completion.messages[0];
 126    let system_prompt = system_message.content[0].to_str().unwrap();
 127    assert!(
 128        system_prompt.contains("test-shell"),
 129        "unexpected system message: {:?}",
 130        system_message
 131    );
 132    assert!(
 133        system_prompt.contains("## Fixing Diagnostics"),
 134        "unexpected system message: {:?}",
 135        system_message
 136    );
 137}
 138
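     // Only the most recent user message or tool result in each request should be marked `cache: true`.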
 139#[gpui::test]
 140async fn test_prompt_caching(cx: &mut TestAppContext) {
 141    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 142    let fake_model = model.as_fake();
 143
 144    // Send initial user message and verify it's cached
 145    thread
 146        .update(cx, |thread, cx| {
 147            thread.send(UserMessageId::new(), ["Message 1"], cx)
 148        })
 149        .unwrap();
 150    cx.run_until_parked();
 151
 152    let completion = fake_model.pending_completions().pop().unwrap();
 153    assert_eq!(
 154        completion.messages[1..],
 155        vec![LanguageModelRequestMessage {
 156            role: Role::User,
 157            content: vec!["Message 1".into()],
 158            cache: true
 159        }]
 160    );
 161    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 162        "Response to Message 1".into(),
 163    ));
 164    fake_model.end_last_completion_stream();
 165    cx.run_until_parked();
 166
 167    // Send another user message and verify only the latest is cached
 168    thread
 169        .update(cx, |thread, cx| {
 170            thread.send(UserMessageId::new(), ["Message 2"], cx)
 171        })
 172        .unwrap();
 173    cx.run_until_parked();
 174
 175    let completion = fake_model.pending_completions().pop().unwrap();
 176    assert_eq!(
 177        completion.messages[1..],
 178        vec![
 179            LanguageModelRequestMessage {
 180                role: Role::User,
 181                content: vec!["Message 1".into()],
 182                cache: false
 183            },
 184            LanguageModelRequestMessage {
 185                role: Role::Assistant,
 186                content: vec!["Response to Message 1".into()],
 187                cache: false
 188            },
 189            LanguageModelRequestMessage {
 190                role: Role::User,
 191                content: vec!["Message 2".into()],
 192                cache: true
 193            }
 194        ]
 195    );
 196    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 197        "Response to Message 2".into(),
 198    ));
 199    fake_model.end_last_completion_stream();
 200    cx.run_until_parked();
 201
 202    // Simulate a tool call and verify that the latest tool result is cached
 203    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 204    thread
 205        .update(cx, |thread, cx| {
 206            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
 207        })
 208        .unwrap();
 209    cx.run_until_parked();
 210
 211    let tool_use = LanguageModelToolUse {
 212        id: "tool_1".into(),
 213        name: EchoTool.name().into(),
 214        raw_input: json!({"text": "test"}).to_string(),
 215        input: json!({"text": "test"}),
 216        is_input_complete: true,
 217    };
 218    fake_model
 219        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 220    fake_model.end_last_completion_stream();
 221    cx.run_until_parked();
 222
 223    let completion = fake_model.pending_completions().pop().unwrap();
 224    let tool_result = LanguageModelToolResult {
 225        tool_use_id: "tool_1".into(),
 226        tool_name: EchoTool.name().into(),
 227        is_error: false,
 228        content: "test".into(),
 229        output: Some("test".into()),
 230    };
 231    assert_eq!(
 232        completion.messages[1..],
 233        vec![
 234            LanguageModelRequestMessage {
 235                role: Role::User,
 236                content: vec!["Message 1".into()],
 237                cache: false
 238            },
 239            LanguageModelRequestMessage {
 240                role: Role::Assistant,
 241                content: vec!["Response to Message 1".into()],
 242                cache: false
 243            },
 244            LanguageModelRequestMessage {
 245                role: Role::User,
 246                content: vec!["Message 2".into()],
 247                cache: false
 248            },
 249            LanguageModelRequestMessage {
 250                role: Role::Assistant,
 251                content: vec!["Response to Message 2".into()],
 252                cache: false
 253            },
 254            LanguageModelRequestMessage {
 255                role: Role::User,
 256                content: vec!["Use the echo tool".into()],
 257                cache: false
 258            },
 259            LanguageModelRequestMessage {
 260                role: Role::Assistant,
 261                content: vec![MessageContent::ToolUse(tool_use)],
 262                cache: false
 263            },
 264            LanguageModelRequestMessage {
 265                role: Role::User,
 266                content: vec![MessageContent::ToolResult(tool_result)],
 267                cache: true
 268            }
 269        ]
 270    );
 271}
 272
 273#[gpui::test]
 274#[ignore = "can't run on CI yet"]
 275async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 276    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 277
 278    // Test a tool call that's likely to complete *before* streaming stops.
 279    let events = thread
 280        .update(cx, |thread, cx| {
 281            thread.add_tool(EchoTool);
 282            thread.send(
 283                UserMessageId::new(),
 284                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 285                cx,
 286            )
 287        })
 288        .unwrap()
 289        .collect()
 290        .await;
 291    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 292
  293    // Test a tool call that's likely to complete *after* streaming stops.
 294    let events = thread
 295        .update(cx, |thread, cx| {
 296            thread.remove_tool(&AgentTool::name(&EchoTool));
 297            thread.add_tool(DelayTool);
 298            thread.send(
 299                UserMessageId::new(),
 300                [
 301                    "Now call the delay tool with 200ms.",
  302                    "When the timer goes off, echo the output of the tool.",
 303                ],
 304                cx,
 305            )
 306        })
 307        .unwrap()
 308        .collect()
 309        .await;
 310    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 311    thread.update(cx, |thread, _cx| {
 312        assert!(
 313            thread
 314                .last_message()
 315                .unwrap()
 316                .as_agent_message()
 317                .unwrap()
 318                .content
 319                .iter()
 320                .any(|content| {
 321                    if let AgentMessageContent::Text(text) = content {
 322                        text.contains("Ding")
 323                    } else {
 324                        false
 325                    }
 326                }),
 327            "{}",
 328            thread.to_markdown()
 329        );
 330    });
 331}
 332
 333#[gpui::test]
 334#[ignore = "can't run on CI yet"]
 335async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 336    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 337
  338    // Watch the event stream so we can observe partially streamed tool input before it completes.
 339    let mut events = thread
 340        .update(cx, |thread, cx| {
 341            thread.add_tool(WordListTool);
 342            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 343        })
 344        .unwrap();
 345
 346    let mut saw_partial_tool_use = false;
 347    while let Some(event) = events.next().await {
 348        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
 349            thread.update(cx, |thread, _cx| {
 350                // Look for a tool use in the thread's last message
 351                let message = thread.last_message().unwrap();
 352                let agent_message = message.as_agent_message().unwrap();
 353                let last_content = agent_message.content.last().unwrap();
 354                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 355                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 356                    if tool_call.status == acp::ToolCallStatus::Pending {
 357                        if !last_tool_use.is_input_complete
 358                            && last_tool_use.input.get("g").is_none()
 359                        {
 360                            saw_partial_tool_use = true;
 361                        }
 362                    } else {
 363                        last_tool_use
 364                            .input
 365                            .get("a")
 366                            .expect("'a' has streamed because input is now complete");
 367                        last_tool_use
 368                            .input
 369                            .get("g")
 370                            .expect("'g' has streamed because input is now complete");
 371                    }
 372                } else {
 373                    panic!("last content should be a tool use");
 374                }
 375            });
 376        }
 377    }
 378
 379    assert!(
 380        saw_partial_tool_use,
 381        "should see at least one partially streamed tool use in the history"
 382    );
 383}
 384
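     // Walks the permission flow: allow one call, reject another, then choose "always allow" so later calls run without prompting.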
 385#[gpui::test]
 386async fn test_tool_authorization(cx: &mut TestAppContext) {
 387    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 388    let fake_model = model.as_fake();
 389
 390    let mut events = thread
 391        .update(cx, |thread, cx| {
 392            thread.add_tool(ToolRequiringPermission);
 393            thread.send(UserMessageId::new(), ["abc"], cx)
 394        })
 395        .unwrap();
 396    cx.run_until_parked();
 397    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 398        LanguageModelToolUse {
 399            id: "tool_id_1".into(),
 400            name: ToolRequiringPermission.name().into(),
 401            raw_input: "{}".into(),
 402            input: json!({}),
 403            is_input_complete: true,
 404        },
 405    ));
 406    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 407        LanguageModelToolUse {
 408            id: "tool_id_2".into(),
 409            name: ToolRequiringPermission.name().into(),
 410            raw_input: "{}".into(),
 411            input: json!({}),
 412            is_input_complete: true,
 413        },
 414    ));
 415    fake_model.end_last_completion_stream();
 416    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 417    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 418
 419    // Approve the first
 420    tool_call_auth_1
 421        .response
 422        .send(tool_call_auth_1.options[1].id.clone())
 423        .unwrap();
 424    cx.run_until_parked();
 425
 426    // Reject the second
 427    tool_call_auth_2
 428        .response
  429        .send(tool_call_auth_2.options[2].id.clone())
 430        .unwrap();
 431    cx.run_until_parked();
 432
 433    let completion = fake_model.pending_completions().pop().unwrap();
 434    let message = completion.messages.last().unwrap();
 435    assert_eq!(
 436        message.content,
 437        vec![
 438            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 439                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 440                tool_name: ToolRequiringPermission.name().into(),
 441                is_error: false,
 442                content: "Allowed".into(),
 443                output: Some("Allowed".into())
 444            }),
 445            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 446                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 447                tool_name: ToolRequiringPermission.name().into(),
 448                is_error: true,
 449                content: "Permission to run tool denied by user".into(),
 450                output: None
 451            })
 452        ]
 453    );
 454
 455    // Simulate yet another tool call.
 456    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 457        LanguageModelToolUse {
 458            id: "tool_id_3".into(),
 459            name: ToolRequiringPermission.name().into(),
 460            raw_input: "{}".into(),
 461            input: json!({}),
 462            is_input_complete: true,
 463        },
 464    ));
 465    fake_model.end_last_completion_stream();
 466
 467    // Respond by always allowing tools.
 468    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 469    tool_call_auth_3
 470        .response
 471        .send(tool_call_auth_3.options[0].id.clone())
 472        .unwrap();
 473    cx.run_until_parked();
 474    let completion = fake_model.pending_completions().pop().unwrap();
 475    let message = completion.messages.last().unwrap();
 476    assert_eq!(
 477        message.content,
 478        vec![language_model::MessageContent::ToolResult(
 479            LanguageModelToolResult {
 480                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 481                tool_name: ToolRequiringPermission.name().into(),
 482                is_error: false,
 483                content: "Allowed".into(),
 484                output: Some("Allowed".into())
 485            }
 486        )]
 487    );
 488
 489    // Simulate a final tool call, ensuring we don't trigger authorization.
 490    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 491        LanguageModelToolUse {
 492            id: "tool_id_4".into(),
 493            name: ToolRequiringPermission.name().into(),
 494            raw_input: "{}".into(),
 495            input: json!({}),
 496            is_input_complete: true,
 497        },
 498    ));
 499    fake_model.end_last_completion_stream();
 500    cx.run_until_parked();
 501    let completion = fake_model.pending_completions().pop().unwrap();
 502    let message = completion.messages.last().unwrap();
 503    assert_eq!(
 504        message.content,
 505        vec![language_model::MessageContent::ToolResult(
 506            LanguageModelToolResult {
 507                tool_use_id: "tool_id_4".into(),
 508                tool_name: ToolRequiringPermission.name().into(),
 509                is_error: false,
 510                content: "Allowed".into(),
 511                output: Some("Allowed".into())
 512            }
 513        )]
 514    );
 515}
 516
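     // A tool use that names an unregistered tool should surface as a pending tool call that immediately fails.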
 517#[gpui::test]
 518async fn test_tool_hallucination(cx: &mut TestAppContext) {
 519    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 520    let fake_model = model.as_fake();
 521
 522    let mut events = thread
 523        .update(cx, |thread, cx| {
 524            thread.send(UserMessageId::new(), ["abc"], cx)
 525        })
 526        .unwrap();
 527    cx.run_until_parked();
 528    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 529        LanguageModelToolUse {
 530            id: "tool_id_1".into(),
 531            name: "nonexistent_tool".into(),
 532            raw_input: "{}".into(),
 533            input: json!({}),
 534            is_input_complete: true,
 535        },
 536    ));
 537    fake_model.end_last_completion_stream();
 538
 539    let tool_call = expect_tool_call(&mut events).await;
 540    assert_eq!(tool_call.title, "nonexistent_tool");
 541    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 542    let update = expect_tool_call_update_fields(&mut events).await;
 543    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 544}
 545
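     // After the tool use limit is reached, resume() continues the turn with a "Continue where you left off" message; resuming without hitting the limit is an error.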
 546#[gpui::test]
 547async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
 548    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 549    let fake_model = model.as_fake();
 550
 551    let events = thread
 552        .update(cx, |thread, cx| {
 553            thread.add_tool(EchoTool);
 554            thread.send(UserMessageId::new(), ["abc"], cx)
 555        })
 556        .unwrap();
 557    cx.run_until_parked();
 558    let tool_use = LanguageModelToolUse {
 559        id: "tool_id_1".into(),
 560        name: EchoTool.name().into(),
 561        raw_input: "{}".into(),
 562        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 563        is_input_complete: true,
 564    };
 565    fake_model
 566        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 567    fake_model.end_last_completion_stream();
 568
 569    cx.run_until_parked();
 570    let completion = fake_model.pending_completions().pop().unwrap();
 571    let tool_result = LanguageModelToolResult {
 572        tool_use_id: "tool_id_1".into(),
 573        tool_name: EchoTool.name().into(),
 574        is_error: false,
 575        content: "def".into(),
 576        output: Some("def".into()),
 577    };
 578    assert_eq!(
 579        completion.messages[1..],
 580        vec![
 581            LanguageModelRequestMessage {
 582                role: Role::User,
 583                content: vec!["abc".into()],
 584                cache: false
 585            },
 586            LanguageModelRequestMessage {
 587                role: Role::Assistant,
 588                content: vec![MessageContent::ToolUse(tool_use.clone())],
 589                cache: false
 590            },
 591            LanguageModelRequestMessage {
 592                role: Role::User,
 593                content: vec![MessageContent::ToolResult(tool_result.clone())],
 594                cache: true
 595            },
 596        ]
 597    );
 598
  599    // Simulate reaching the tool use limit.
 600    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 601        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 602    ));
 603    fake_model.end_last_completion_stream();
 604    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 605    assert!(
 606        last_event
 607            .unwrap_err()
 608            .is::<language_model::ToolUseLimitReachedError>()
 609    );
 610
 611    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
 612    cx.run_until_parked();
 613    let completion = fake_model.pending_completions().pop().unwrap();
 614    assert_eq!(
 615        completion.messages[1..],
 616        vec![
 617            LanguageModelRequestMessage {
 618                role: Role::User,
 619                content: vec!["abc".into()],
 620                cache: false
 621            },
 622            LanguageModelRequestMessage {
 623                role: Role::Assistant,
 624                content: vec![MessageContent::ToolUse(tool_use)],
 625                cache: false
 626            },
 627            LanguageModelRequestMessage {
 628                role: Role::User,
 629                content: vec![MessageContent::ToolResult(tool_result)],
 630                cache: false
 631            },
 632            LanguageModelRequestMessage {
 633                role: Role::User,
 634                content: vec!["Continue where you left off".into()],
 635                cache: true
 636            }
 637        ]
 638    );
 639
 640    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
 641    fake_model.end_last_completion_stream();
 642    events.collect::<Vec<_>>().await;
 643    thread.read_with(cx, |thread, _cx| {
 644        assert_eq!(
 645            thread.last_message().unwrap().to_markdown(),
 646            indoc! {"
 647                ## Assistant
 648
 649                Done
 650            "}
 651        )
 652    });
 653
  654    // Ensure resume returns an error when the tool use limit was *not* reached.
 655    let error = thread
 656        .update(cx, |thread, cx| thread.resume(cx))
 657        .unwrap_err();
 658    assert_eq!(
 659        error.to_string(),
 660        "can only resume after tool use limit is reached"
 661    )
 662}
 663
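     // Reaching the tool use limit should not prevent sending a brand-new user message afterwards.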
 664#[gpui::test]
 665async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
 666    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 667    let fake_model = model.as_fake();
 668
 669    let events = thread
 670        .update(cx, |thread, cx| {
 671            thread.add_tool(EchoTool);
 672            thread.send(UserMessageId::new(), ["abc"], cx)
 673        })
 674        .unwrap();
 675    cx.run_until_parked();
 676
 677    let tool_use = LanguageModelToolUse {
 678        id: "tool_id_1".into(),
 679        name: EchoTool.name().into(),
 680        raw_input: "{}".into(),
 681        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 682        is_input_complete: true,
 683    };
 684    let tool_result = LanguageModelToolResult {
 685        tool_use_id: "tool_id_1".into(),
 686        tool_name: EchoTool.name().into(),
 687        is_error: false,
 688        content: "def".into(),
 689        output: Some("def".into()),
 690    };
 691    fake_model
 692        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 693    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 694        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 695    ));
 696    fake_model.end_last_completion_stream();
 697    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 698    assert!(
 699        last_event
 700            .unwrap_err()
 701            .is::<language_model::ToolUseLimitReachedError>()
 702    );
 703
 704    thread
 705        .update(cx, |thread, cx| {
 706            thread.send(UserMessageId::new(), vec!["ghi"], cx)
 707        })
 708        .unwrap();
 709    cx.run_until_parked();
 710    let completion = fake_model.pending_completions().pop().unwrap();
 711    assert_eq!(
 712        completion.messages[1..],
 713        vec![
 714            LanguageModelRequestMessage {
 715                role: Role::User,
 716                content: vec!["abc".into()],
 717                cache: false
 718            },
 719            LanguageModelRequestMessage {
 720                role: Role::Assistant,
 721                content: vec![MessageContent::ToolUse(tool_use)],
 722                cache: false
 723            },
 724            LanguageModelRequestMessage {
 725                role: Role::User,
 726                content: vec![MessageContent::ToolResult(tool_result)],
 727                cache: false
 728            },
 729            LanguageModelRequestMessage {
 730                role: Role::User,
 731                content: vec!["ghi".into()],
 732                cache: true
 733            }
 734        ]
 735    );
 736}
 737
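     /// Reads the next thread event and panics unless it is a `ThreadEvent::ToolCall`.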
 738async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 739    let event = events
 740        .next()
 741        .await
  742        .expect("no tool call event received")
 743        .unwrap();
 744    match event {
 745        ThreadEvent::ToolCall(tool_call) => return tool_call,
 746        event => {
 747            panic!("Unexpected event {event:?}");
 748        }
 749    }
 750}
 751
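     /// Reads the next thread event and panics unless it is a `ToolCallUpdate::UpdateFields` update.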
 752async fn expect_tool_call_update_fields(
 753    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 754) -> acp::ToolCallUpdate {
 755    let event = events
 756        .next()
 757        .await
  758        .expect("no tool call update event received")
 759        .unwrap();
 760    match event {
 761        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
 762            return update;
 763        }
 764        event => {
 765            panic!("Unexpected event {event:?}");
 766        }
 767    }
 768}
 769
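     /// Skips ahead to the next `ToolCallAuthorization` event, asserting that it offers the AllowAlways, AllowOnce, and RejectOnce options in that order.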
 770async fn next_tool_call_authorization(
 771    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 772) -> ToolCallAuthorization {
 773    loop {
 774        let event = events
 775            .next()
 776            .await
 777            .expect("no tool call authorization event received")
 778            .unwrap();
 779        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 780            let permission_kinds = tool_call_authorization
 781                .options
 782                .iter()
 783                .map(|o| o.kind)
 784                .collect::<Vec<_>>();
 785            assert_eq!(
 786                permission_kinds,
 787                vec![
 788                    acp::PermissionOptionKind::AllowAlways,
 789                    acp::PermissionOptionKind::AllowOnce,
 790                    acp::PermissionOptionKind::RejectOnce,
 791                ]
 792            );
 793            return tool_call_authorization;
 794        }
 795    }
 796}
 797
 798#[gpui::test]
 799#[ignore = "can't run on CI yet"]
 800async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 801    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 802
 803    // Test concurrent tool calls with different delay times
 804    let events = thread
 805        .update(cx, |thread, cx| {
 806            thread.add_tool(DelayTool);
 807            thread.send(
 808                UserMessageId::new(),
 809                [
 810                    "Call the delay tool twice in the same message.",
 811                    "Once with 100ms. Once with 300ms.",
 812                    "When both timers are complete, describe the outputs.",
 813                ],
 814                cx,
 815            )
 816        })
 817        .unwrap()
 818        .collect()
 819        .await;
 820
 821    let stop_reasons = stop_events(events);
 822    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 823
 824    thread.update(cx, |thread, _cx| {
 825        let last_message = thread.last_message().unwrap();
 826        let agent_message = last_message.as_agent_message().unwrap();
 827        let text = agent_message
 828            .content
 829            .iter()
 830            .filter_map(|content| {
 831                if let AgentMessageContent::Text(text) = content {
 832                    Some(text.as_str())
 833                } else {
 834                    None
 835                }
 836            })
 837            .collect::<String>();
 838
 839        assert!(text.contains("Ding"));
 840    });
 841}
 842
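     // Profiles defined in settings control which of the registered tools are offered to the model.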
 843#[gpui::test]
 844async fn test_profiles(cx: &mut TestAppContext) {
 845    let ThreadTest {
 846        model, thread, fs, ..
 847    } = setup(cx, TestModel::Fake).await;
 848    let fake_model = model.as_fake();
 849
 850    thread.update(cx, |thread, _cx| {
 851        thread.add_tool(DelayTool);
 852        thread.add_tool(EchoTool);
 853        thread.add_tool(InfiniteTool);
 854    });
 855
 856    // Override profiles and wait for settings to be loaded.
 857    fs.insert_file(
 858        paths::settings_file(),
 859        json!({
 860            "agent": {
 861                "profiles": {
 862                    "test-1": {
 863                        "name": "Test Profile 1",
 864                        "tools": {
 865                            EchoTool.name(): true,
 866                            DelayTool.name(): true,
 867                        }
 868                    },
 869                    "test-2": {
 870                        "name": "Test Profile 2",
 871                        "tools": {
 872                            InfiniteTool.name(): true,
 873                        }
 874                    }
 875                }
 876            }
 877        })
 878        .to_string()
 879        .into_bytes(),
 880    )
 881    .await;
 882    cx.run_until_parked();
 883
  884    // Select the test-1 profile and verify it exposes the echo and delay tools.
 885    thread
 886        .update(cx, |thread, cx| {
 887            thread.set_profile(AgentProfileId("test-1".into()));
 888            thread.send(UserMessageId::new(), ["test"], cx)
 889        })
 890        .unwrap();
 891    cx.run_until_parked();
 892
 893    let mut pending_completions = fake_model.pending_completions();
 894    assert_eq!(pending_completions.len(), 1);
 895    let completion = pending_completions.pop().unwrap();
 896    let tool_names: Vec<String> = completion
 897        .tools
 898        .iter()
 899        .map(|tool| tool.name.clone())
 900        .collect();
 901    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 902    fake_model.end_last_completion_stream();
 903
  904    // Switch to the test-2 profile and verify that it has only the infinite tool.
 905    thread
 906        .update(cx, |thread, cx| {
 907            thread.set_profile(AgentProfileId("test-2".into()));
 908            thread.send(UserMessageId::new(), ["test2"], cx)
 909        })
 910        .unwrap();
 911    cx.run_until_parked();
 912    let mut pending_completions = fake_model.pending_completions();
 913    assert_eq!(pending_completions.len(), 1);
 914    let completion = pending_completions.pop().unwrap();
 915    let tool_names: Vec<String> = completion
 916        .tools
 917        .iter()
 918        .map(|tool| tool.name.clone())
 919        .collect();
 920    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 921}
 922
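     // Cancelling mid-turn should close the event stream with a Canceled stop, even while a tool is still running, and the thread should accept new messages afterwards.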
 923#[gpui::test]
 924#[ignore = "can't run on CI yet"]
 925async fn test_cancellation(cx: &mut TestAppContext) {
 926    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 927
 928    let mut events = thread
 929        .update(cx, |thread, cx| {
 930            thread.add_tool(InfiniteTool);
 931            thread.add_tool(EchoTool);
 932            thread.send(
 933                UserMessageId::new(),
 934                ["Call the echo tool, then call the infinite tool, then explain their output"],
 935                cx,
 936            )
 937        })
 938        .unwrap();
 939
 940    // Wait until both tools are called.
 941    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 942    let mut echo_id = None;
 943    let mut echo_completed = false;
 944    while let Some(event) = events.next().await {
 945        match event.unwrap() {
 946            ThreadEvent::ToolCall(tool_call) => {
 947                assert_eq!(tool_call.title, expected_tools.remove(0));
 948                if tool_call.title == "Echo" {
 949                    echo_id = Some(tool_call.id);
 950                }
 951            }
 952            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 953                acp::ToolCallUpdate {
 954                    id,
 955                    fields:
 956                        acp::ToolCallUpdateFields {
 957                            status: Some(acp::ToolCallStatus::Completed),
 958                            ..
 959                        },
 960                },
 961            )) if Some(&id) == echo_id.as_ref() => {
 962                echo_completed = true;
 963            }
 964            _ => {}
 965        }
 966
 967        if expected_tools.is_empty() && echo_completed {
 968            break;
 969        }
 970    }
 971
 972    // Cancel the current send and ensure that the event stream is closed, even
 973    // if one of the tools is still running.
 974    thread.update(cx, |thread, cx| thread.cancel(cx));
 975    let events = events.collect::<Vec<_>>().await;
 976    let last_event = events.last();
 977    assert!(
 978        matches!(
 979            last_event,
 980            Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
 981        ),
 982        "unexpected event {last_event:?}"
 983    );
 984
 985    // Ensure we can still send a new message after cancellation.
 986    let events = thread
 987        .update(cx, |thread, cx| {
 988            thread.send(
 989                UserMessageId::new(),
 990                ["Testing: reply with 'Hello' then stop."],
 991                cx,
 992            )
 993        })
 994        .unwrap()
 995        .collect::<Vec<_>>()
 996        .await;
 997    thread.update(cx, |thread, _cx| {
 998        let message = thread.last_message().unwrap();
 999        let agent_message = message.as_agent_message().unwrap();
1000        assert_eq!(
1001            agent_message.content,
1002            vec![AgentMessageContent::Text("Hello".to_string())]
1003        );
1004    });
1005    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1006}
1007
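     // Sending a new message while an earlier send is still streaming should cancel the earlier turn.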
1008#[gpui::test]
1009async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1010    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1011    let fake_model = model.as_fake();
1012
1013    let events_1 = thread
1014        .update(cx, |thread, cx| {
1015            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1016        })
1017        .unwrap();
1018    cx.run_until_parked();
1019    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1020    cx.run_until_parked();
1021
1022    let events_2 = thread
1023        .update(cx, |thread, cx| {
1024            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1025        })
1026        .unwrap();
1027    cx.run_until_parked();
1028    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1029    fake_model
1030        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1031    fake_model.end_last_completion_stream();
1032
1033    let events_1 = events_1.collect::<Vec<_>>().await;
1034    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
1035    let events_2 = events_2.collect::<Vec<_>>().await;
1036    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1037}
1038
1039#[gpui::test]
1040async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1041    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1042    let fake_model = model.as_fake();
1043
1044    let events_1 = thread
1045        .update(cx, |thread, cx| {
1046            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1047        })
1048        .unwrap();
1049    cx.run_until_parked();
1050    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1051    fake_model
1052        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1053    fake_model.end_last_completion_stream();
1054    let events_1 = events_1.collect::<Vec<_>>().await;
1055
1056    let events_2 = thread
1057        .update(cx, |thread, cx| {
1058            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1059        })
1060        .unwrap();
1061    cx.run_until_parked();
1062    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1063    fake_model
1064        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1065    fake_model.end_last_completion_stream();
1066    let events_2 = events_2.collect::<Vec<_>>().await;
1067
1068    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1069    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1070}
1071
1072#[gpui::test]
1073async fn test_refusal(cx: &mut TestAppContext) {
1074    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1075    let fake_model = model.as_fake();
1076
1077    let events = thread
1078        .update(cx, |thread, cx| {
1079            thread.send(UserMessageId::new(), ["Hello"], cx)
1080        })
1081        .unwrap();
1082    cx.run_until_parked();
1083    thread.read_with(cx, |thread, _| {
1084        assert_eq!(
1085            thread.to_markdown(),
1086            indoc! {"
1087                ## User
1088
1089                Hello
1090            "}
1091        );
1092    });
1093
1094    fake_model.send_last_completion_stream_text_chunk("Hey!");
1095    cx.run_until_parked();
1096    thread.read_with(cx, |thread, _| {
1097        assert_eq!(
1098            thread.to_markdown(),
1099            indoc! {"
1100                ## User
1101
1102                Hello
1103
1104                ## Assistant
1105
1106                Hey!
1107            "}
1108        );
1109    });
1110
 1111    // If the model refuses to continue, the thread should roll back, removing the last user message and everything after it.
1112    fake_model
1113        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1114    let events = events.collect::<Vec<_>>().await;
1115    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1116    thread.read_with(cx, |thread, _| {
1117        assert_eq!(thread.to_markdown(), "");
1118    });
1119}
1120
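     // Truncating at a message removes it and everything after it, and the thread keeps working afterwards.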
1121#[gpui::test]
1122async fn test_truncate(cx: &mut TestAppContext) {
1123    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1124    let fake_model = model.as_fake();
1125
1126    let message_id = UserMessageId::new();
1127    thread
1128        .update(cx, |thread, cx| {
1129            thread.send(message_id.clone(), ["Hello"], cx)
1130        })
1131        .unwrap();
1132    cx.run_until_parked();
1133    thread.read_with(cx, |thread, _| {
1134        assert_eq!(
1135            thread.to_markdown(),
1136            indoc! {"
1137                ## User
1138
1139                Hello
1140            "}
1141        );
1142    });
1143
1144    fake_model.send_last_completion_stream_text_chunk("Hey!");
1145    cx.run_until_parked();
1146    thread.read_with(cx, |thread, _| {
1147        assert_eq!(
1148            thread.to_markdown(),
1149            indoc! {"
1150                ## User
1151
1152                Hello
1153
1154                ## Assistant
1155
1156                Hey!
1157            "}
1158        );
1159    });
1160
1161    thread
1162        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1163        .unwrap();
1164    cx.run_until_parked();
1165    thread.read_with(cx, |thread, _| {
1166        assert_eq!(thread.to_markdown(), "");
1167    });
1168
1169    // Ensure we can still send a new message after truncation.
1170    thread
1171        .update(cx, |thread, cx| {
1172            thread.send(UserMessageId::new(), ["Hi"], cx)
1173        })
1174        .unwrap();
1175    thread.update(cx, |thread, _cx| {
1176        assert_eq!(
1177            thread.to_markdown(),
1178            indoc! {"
1179                ## User
1180
1181                Hi
1182            "}
1183        );
1184    });
1185    cx.run_until_parked();
1186    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1187    cx.run_until_parked();
1188    thread.read_with(cx, |thread, _| {
1189        assert_eq!(
1190            thread.to_markdown(),
1191            indoc! {"
1192                ## User
1193
1194                Hi
1195
1196                ## Assistant
1197
1198                Ahoy!
1199            "}
1200        );
1201    });
1202}
1203
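     // The summarization model generates the thread title after the first exchange only; later sends must not trigger another summary request.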
1204#[gpui::test]
1205async fn test_title_generation(cx: &mut TestAppContext) {
1206    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1207    let fake_model = model.as_fake();
1208
1209    let summary_model = Arc::new(FakeLanguageModel::default());
1210    thread.update(cx, |thread, cx| {
1211        thread.set_summarization_model(Some(summary_model.clone()), cx)
1212    });
1213
1214    let send = thread
1215        .update(cx, |thread, cx| {
1216            thread.send(UserMessageId::new(), ["Hello"], cx)
1217        })
1218        .unwrap();
1219    cx.run_until_parked();
1220
1221    fake_model.send_last_completion_stream_text_chunk("Hey!");
1222    fake_model.end_last_completion_stream();
1223    cx.run_until_parked();
1224    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1225
1226    // Ensure the summary model has been invoked to generate a title.
1227    summary_model.send_last_completion_stream_text_chunk("Hello ");
1228    summary_model.send_last_completion_stream_text_chunk("world\nG");
1229    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1230    summary_model.end_last_completion_stream();
1231    send.collect::<Vec<_>>().await;
1232    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1233
1234    // Send another message, ensuring no title is generated this time.
1235    let send = thread
1236        .update(cx, |thread, cx| {
1237            thread.send(UserMessageId::new(), ["Hello again"], cx)
1238        })
1239        .unwrap();
1240    cx.run_until_parked();
1241    fake_model.send_last_completion_stream_text_chunk("Hey again!");
1242    fake_model.end_last_completion_stream();
1243    cx.run_until_parked();
1244    assert_eq!(summary_model.pending_completions(), Vec::new());
1245    send.collect::<Vec<_>>().await;
1246    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1247}
1248
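     // End-to-end test of NativeAgentConnection: model listing and selection, prompting, cancellation, and session cleanup once the ACP thread is dropped.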
1249#[gpui::test]
1250async fn test_agent_connection(cx: &mut TestAppContext) {
1251    cx.update(settings::init);
1252    let templates = Templates::new();
1253
1254    // Initialize language model system with test provider
1255    cx.update(|cx| {
1256        gpui_tokio::init(cx);
1257        client::init_settings(cx);
1258
1259        let http_client = FakeHttpClient::with_404_response();
1260        let clock = Arc::new(clock::FakeSystemClock::new());
1261        let client = Client::new(clock, http_client, cx);
1262        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1263        language_model::init(client.clone(), cx);
1264        language_models::init(user_store.clone(), client.clone(), cx);
1265        Project::init_settings(cx);
1266        LanguageModelRegistry::test(cx);
1267        agent_settings::init(cx);
1268    });
1269    cx.executor().forbid_parking();
1270
1271    // Create a project for new_thread
1272    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
1273    fake_fs.insert_tree(path!("/test"), json!({})).await;
1274    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
1275    let cwd = Path::new("/test");
1276    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1277    let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
1278
1279    // Create agent and connection
1280    let agent = NativeAgent::new(
1281        project.clone(),
1282        history_store,
1283        templates.clone(),
1284        None,
1285        fake_fs.clone(),
1286        &mut cx.to_async(),
1287    )
1288    .await
1289    .unwrap();
1290    let connection = NativeAgentConnection(agent.clone());
1291
1292    // Test model_selector returns Some
1293    let selector_opt = connection.model_selector();
1294    assert!(
1295        selector_opt.is_some(),
1296        "agent2 should always support ModelSelector"
1297    );
1298    let selector = selector_opt.unwrap();
1299
1300    // Test list_models
1301    let listed_models = cx
1302        .update(|cx| selector.list_models(cx))
1303        .await
1304        .expect("list_models should succeed");
1305    let AgentModelList::Grouped(listed_models) = listed_models else {
1306        panic!("Unexpected model list type");
1307    };
1308    assert!(!listed_models.is_empty(), "should have at least one model");
1309    assert_eq!(
1310        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
1311        "fake/fake"
1312    );
1313
1314    // Create a thread using new_thread
1315    let connection_rc = Rc::new(connection.clone());
1316    let acp_thread = cx
1317        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1318        .await
1319        .expect("new_thread should succeed");
1320
1321    // Get the session_id from the AcpThread
1322    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1323
1324    // Test selected_model returns the default
1325    let model = cx
1326        .update(|cx| selector.selected_model(&session_id, cx))
1327        .await
1328        .expect("selected_model should succeed");
1329    let model = cx
1330        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1331        .unwrap();
1332    let model = model.as_fake();
1333    assert_eq!(model.id().0, "fake", "should return default model");
1334
1335    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1336    cx.run_until_parked();
1337    model.send_last_completion_stream_text_chunk("def");
1338    cx.run_until_parked();
1339    acp_thread.read_with(cx, |thread, cx| {
1340        assert_eq!(
1341            thread.to_markdown(cx),
1342            indoc! {"
1343                ## User
1344
1345                abc
1346
1347                ## Assistant
1348
1349                def
1350
1351            "}
1352        )
1353    });
1354
1355    // Test cancel
1356    cx.update(|cx| connection.cancel(&session_id, cx));
 1357    request.await.expect("prompt should resolve cleanly after cancellation");
1358
1359    // Ensure that dropping the ACP thread causes the native thread to be
1360    // dropped as well.
1361    cx.update(|_| drop(acp_thread));
1362    let result = cx
1363        .update(|cx| {
1364            connection.prompt(
1365                Some(acp_thread::UserMessageId::new()),
1366                acp::PromptRequest {
1367                    session_id: session_id.clone(),
1368                    prompt: vec!["ghi".into()],
1369                },
1370                cx,
1371            )
1372        })
1373        .await;
1374    assert_eq!(
1375        result.as_ref().unwrap_err().to_string(),
1376        "Session not found",
1377        "unexpected result: {:?}",
1378        result
1379    );
1380}
1381
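     // Follows a single tool call through its ACP lifecycle: pending with partial input, completed input, in progress, streamed content, and finally completed with raw output.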
1382#[gpui::test]
1383async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1384    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1385    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1386    let fake_model = model.as_fake();
1387
1388    let mut events = thread
1389        .update(cx, |thread, cx| {
1390            thread.send(UserMessageId::new(), ["Think"], cx)
1391        })
1392        .unwrap();
1393    cx.run_until_parked();
1394
1395    // Simulate streaming partial input.
1396    let input = json!({});
1397    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1398        LanguageModelToolUse {
1399            id: "1".into(),
1400            name: ThinkingTool.name().into(),
1401            raw_input: input.to_string(),
1402            input,
1403            is_input_complete: false,
1404        },
1405    ));
1406
1407    // Input streaming completed
1408    let input = json!({ "content": "Thinking hard!" });
1409    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1410        LanguageModelToolUse {
1411            id: "1".into(),
1412            name: "thinking".into(),
1413            raw_input: input.to_string(),
1414            input,
1415            is_input_complete: true,
1416        },
1417    ));
1418    fake_model.end_last_completion_stream();
1419    cx.run_until_parked();
1420
1421    let tool_call = expect_tool_call(&mut events).await;
1422    assert_eq!(
1423        tool_call,
1424        acp::ToolCall {
1425            id: acp::ToolCallId("1".into()),
1426            title: "Thinking".into(),
1427            kind: acp::ToolKind::Think,
1428            status: acp::ToolCallStatus::Pending,
1429            content: vec![],
1430            locations: vec![],
1431            raw_input: Some(json!({})),
1432            raw_output: None,
1433        }
1434    );
1435    let update = expect_tool_call_update_fields(&mut events).await;
1436    assert_eq!(
1437        update,
1438        acp::ToolCallUpdate {
1439            id: acp::ToolCallId("1".into()),
1440            fields: acp::ToolCallUpdateFields {
1441                title: Some("Thinking".into()),
1442                kind: Some(acp::ToolKind::Think),
1443                raw_input: Some(json!({ "content": "Thinking hard!" })),
1444                ..Default::default()
1445            },
1446        }
1447    );
1448    let update = expect_tool_call_update_fields(&mut events).await;
1449    assert_eq!(
1450        update,
1451        acp::ToolCallUpdate {
1452            id: acp::ToolCallId("1".into()),
1453            fields: acp::ToolCallUpdateFields {
1454                status: Some(acp::ToolCallStatus::InProgress),
1455                ..Default::default()
1456            },
1457        }
1458    );
1459    let update = expect_tool_call_update_fields(&mut events).await;
1460    assert_eq!(
1461        update,
1462        acp::ToolCallUpdate {
1463            id: acp::ToolCallId("1".into()),
1464            fields: acp::ToolCallUpdateFields {
1465                content: Some(vec!["Thinking hard!".into()]),
1466                ..Default::default()
1467            },
1468        }
1469    );
1470    let update = expect_tool_call_update_fields(&mut events).await;
1471    assert_eq!(
1472        update,
1473        acp::ToolCallUpdate {
1474            id: acp::ToolCallId("1".into()),
1475            fields: acp::ToolCallUpdateFields {
1476                status: Some(acp::ToolCallStatus::Completed),
1477                raw_output: Some("Finished thinking.".into()),
1478                ..Default::default()
1479            },
1480        }
1481    );
1482}
1483
1484#[gpui::test]
1485async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
1486    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1487    let fake_model = model.as_fake();
1488
1489    let mut events = thread
1490        .update(cx, |thread, cx| {
1491            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1492            thread.send(UserMessageId::new(), ["Hello!"], cx)
1493        })
1494        .unwrap();
1495    cx.run_until_parked();
1496
1497    fake_model.send_last_completion_stream_text_chunk("Hey!");
1498    fake_model.end_last_completion_stream();
1499
1500    let mut retry_events = Vec::new();
1501    while let Some(Ok(event)) = events.next().await {
1502        match event {
1503            ThreadEvent::Retry(retry_status) => {
1504                retry_events.push(retry_status);
1505            }
1506            ThreadEvent::Stop(..) => break,
1507            _ => {}
1508        }
1509    }
1510
1511    assert_eq!(retry_events.len(), 0);
1512    thread.read_with(cx, |thread, _cx| {
1513        assert_eq!(
1514            thread.to_markdown(),
1515            indoc! {"
1516                ## User
1517
1518                Hello!
1519
1520                ## Assistant
1521
1522                Hey!
1523            "}
1524        )
1525    });
1526}
1527
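     // A retryable provider error in Burn mode should emit one Retry event and then succeed on the next attempt.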
1528#[gpui::test]
1529async fn test_send_retry_on_error(cx: &mut TestAppContext) {
1530    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1531    let fake_model = model.as_fake();
1532
1533    let mut events = thread
1534        .update(cx, |thread, cx| {
1535            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1536            thread.send(UserMessageId::new(), ["Hello!"], cx)
1537        })
1538        .unwrap();
1539    cx.run_until_parked();
1540
1541    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
1542        provider: LanguageModelProviderName::new("Anthropic"),
1543        retry_after: Some(Duration::from_secs(3)),
1544    });
1545    fake_model.end_last_completion_stream();
1546
1547    cx.executor().advance_clock(Duration::from_secs(3));
1548    cx.run_until_parked();
1549
1550    fake_model.send_last_completion_stream_text_chunk("Hey!");
1551    fake_model.end_last_completion_stream();
1552
1553    let mut retry_events = Vec::new();
1554    while let Some(Ok(event)) = events.next().await {
1555        match event {
1556            ThreadEvent::Retry(retry_status) => {
1557                retry_events.push(retry_status);
1558            }
1559            ThreadEvent::Stop(..) => break,
1560            _ => {}
1561        }
1562    }
1563
1564    assert_eq!(retry_events.len(), 1);
1565    assert!(matches!(
1566        retry_events[0],
1567        acp_thread::RetryStatus { attempt: 1, .. }
1568    ));
1569    thread.read_with(cx, |thread, _cx| {
1570        assert_eq!(
1571            thread.to_markdown(),
1572            indoc! {"
1573                ## User
1574
1575                Hello!
1576
1577                ## Assistant
1578
1579                Hey!
1580            "}
1581        )
1582    });
1583}
1584
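/// When every retry attempt fails, the thread stops retrying and surfaces the underlying completion error.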
1585#[gpui::test]
1586async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
1587    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1588    let fake_model = model.as_fake();
1589
1590    let mut events = thread
1591        .update(cx, |thread, cx| {
1592            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1593            thread.send(UserMessageId::new(), ["Hello!"], cx)
1594        })
1595        .unwrap();
1596    cx.run_until_parked();
1597
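    // Fail the initial attempt and every one of the MAX_RETRY_ATTEMPTS retries with the same retryable error.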
1598    for _ in 0..(crate::thread::MAX_RETRY_ATTEMPTS + 1) {
1599        fake_model.send_last_completion_stream_error(
1600            LanguageModelCompletionError::ServerOverloaded {
1601                provider: LanguageModelProviderName::new("Anthropic"),
1602                retry_after: Some(Duration::from_secs(3)),
1603            },
1604        );
1605        fake_model.end_last_completion_stream();
1606        cx.executor().advance_clock(Duration::from_secs(3));
1607        cx.run_until_parked();
1608    }
1609
1610    let mut errors = Vec::new();
1611    let mut retry_events = Vec::new();
1612    while let Some(event) = events.next().await {
1613        match event {
1614            Ok(ThreadEvent::Retry(retry_status)) => {
1615                retry_events.push(retry_status);
1616            }
1617            Ok(ThreadEvent::Stop(..)) => break,
1618            Err(error) => errors.push(error),
1619            _ => {}
1620        }
1621    }
1622
1623    assert_eq!(
1624        retry_events.len(),
1625        crate::thread::MAX_RETRY_ATTEMPTS as usize
1626    );
1627    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
1628        assert_eq!(retry_events[i].attempt, i + 1);
1629    }
1630    assert_eq!(errors.len(), 1);
1631    let error = errors[0]
1632        .downcast_ref::<LanguageModelCompletionError>()
1633        .unwrap();
1634    assert!(matches!(
1635        error,
1636        LanguageModelCompletionError::ServerOverloaded { .. }
1637    ));
1638}
1639
1640/// Collects the stop events from a completed turn for asserting against in tests
1641fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1642    result_events
1643        .into_iter()
1644        .filter_map(|event| match event.unwrap() {
1645            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1646            _ => None,
1647        })
1648        .collect()
1649}
1650
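/// Handles returned by `setup` for driving a thread and inspecting its state in tests.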
1651struct ThreadTest {
1652    model: Arc<dyn LanguageModel>,
1653    thread: Entity<Thread>,
1654    project_context: Entity<ProjectContext>,
1655    fs: Arc<FakeFs>,
1656}
1657
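/// The language model a test runs against: a live Claude Sonnet 4 variant resolved from the registry, or the in-process `FakeLanguageModel`.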
1658enum TestModel {
1659    Sonnet4,
1660    Sonnet4Thinking,
1661    Fake,
1662}
1663
1664impl TestModel {
1665    fn id(&self) -> LanguageModelId {
1666        match self {
1667            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1668            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1669            TestModel::Fake => unreachable!(),
1670        }
1671    }
1672}
1673
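/// Builds a `Thread` backed by the requested model, along with a fake filesystem, a test project, and user settings that enable the test tools.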
1674async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1675    cx.executor().allow_parking();
1676
1677    let fs = FakeFs::new(cx.background_executor.clone());
1678    fs.create_dir(paths::settings_file().parent().unwrap())
1679        .await
1680        .unwrap();
1681    fs.insert_file(
1682        paths::settings_file(),
1683        json!({
1684            "agent": {
1685                "default_profile": "test-profile",
1686                "profiles": {
1687                    "test-profile": {
1688                        "name": "Test Profile",
1689                        "tools": {
1690                            EchoTool.name(): true,
1691                            DelayTool.name(): true,
1692                            WordListTool.name(): true,
1693                            ToolRequiringPermission.name(): true,
1694                            InfiniteTool.name(): true,
1695                        }
1696                    }
1697                }
1698            }
1699        })
1700        .to_string()
1701        .into_bytes(),
1702    )
1703    .await;
1704
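    // Initialize global state: settings, an HTTP client, the client and user store, and the language model providers.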
1705    cx.update(|cx| {
1706        settings::init(cx);
1707        Project::init_settings(cx);
1708        agent_settings::init(cx);
1709        gpui_tokio::init(cx);
1710        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1711        cx.set_http_client(Arc::new(http_client));
1712
1713        client::init_settings(cx);
1714        let client = Client::production(cx);
1715        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1716        language_model::init(client.clone(), cx);
1717        language_models::init(user_store.clone(), client.clone(), cx);
1718
1719        watch_settings(fs.clone(), cx);
1720    });
1721
1722    let templates = Templates::new();
1723
1724    fs.insert_tree(path!("/test"), json!({})).await;
1725    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1726
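    // The fake model is constructed directly; real models are resolved from the registry and their provider is authenticated first.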
1727    let model = cx
1728        .update(|cx| {
1729            if let TestModel::Fake = model {
1730                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1731            } else {
1732                let model_id = model.id();
1733                let models = LanguageModelRegistry::read_global(cx);
1734                let model = models
1735                    .available_models(cx)
1736                    .find(|model| model.id() == model_id)
1737                    .unwrap();
1738
1739                let provider = models.provider(&model.provider_id()).unwrap();
1740                let authenticated = provider.authenticate(cx);
1741
1742                cx.spawn(async move |_cx| {
1743                    authenticated.await.unwrap();
1744                    model
1745                })
1746            }
1747        })
1748        .await;
1749
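    // Assemble the thread from a fresh project context, the project's context server registry, and an action log.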
1750    let project_context = cx.new(|_cx| ProjectContext::default());
1751    let context_server_registry =
1752        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1753    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1754    let thread = cx.new(|cx| {
1755        Thread::new(
1756            project,
1757            project_context.clone(),
1758            context_server_registry,
1759            action_log,
1760            templates,
1761            Some(model.clone()),
1762            cx,
1763        )
1764    });
1765    ThreadTest {
1766        model,
1767        thread,
1768        project_context,
1769        fs,
1770    }
1771}
1772
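/// Runs at binary startup via `ctor` and enables `env_logger` output when `RUST_LOG` is set.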
1773#[cfg(test)]
1774#[ctor::ctor]
1775fn init_logger() {
1776    if std::env::var("RUST_LOG").is_ok() {
1777        env_logger::init();
1778    }
1779}
1780
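/// Forwards changes to the settings file on the filesystem into the global `SettingsStore` so that settings edits take effect during tests.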
1781fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1782    let fs = fs.clone();
1783    cx.spawn({
1784        async move |cx| {
1785            let mut new_settings_content_rx = settings::watch_config_file(
1786                cx.background_executor(),
1787                fs,
1788                paths::settings_file().clone(),
1789            );
1790
1791            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1792                cx.update(|cx| {
1793                    SettingsStore::update_global(cx, |settings, cx| {
1794                        settings.set_user_settings(&new_settings_content, cx)
1795                    })
1796                })
1797                .ok();
1798            }
1799        }
1800    })
1801    .detach();
1802}