mod.rs

use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
use gpui::{
    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
    fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;

mod test_tools;
use test_tools::*;

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap()
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

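// Checks that the rendered system prompt picks up the project context (the
// configured shell) and includes the static prompt sections such as
// "## Fixing Diagnostics".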
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

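// Checks that the cache marker is only set on the most recent user message or
// tool result in each completion request.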
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

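// Exercises the tool permission flow: the first call is approved, the second
// rejected, the third answered with "always allow", after which a fourth call
// runs without prompting for authorization.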
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first
    tool_call_auth_1
        .response
        .send(tool_call_auth_1.options[1].id.clone())
        .unwrap();
    cx.run_until_parked();

    // Reject the second
    tool_call_auth_2
        .response
        .send(tool_call_auth_2.options[2].id.clone())
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: None
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools.
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(tool_call_auth_3.options[0].id.clone())
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}

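// A tool use that names a tool the thread doesn't know about should surface as
// a pending tool call that immediately transitions to Failed.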
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}

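// After the tool use limit is reached, `resume` should continue the turn by
// appending a "Continue where you left off" user message, and resuming when
// the limit was *not* reached should return an error.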
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}

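// Sending a regular user message after the tool use limit is reached should
// work without inserting the synthetic "continue" message.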
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}

async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
    let event = events
        .next()
        .await
        .expect("no tool call event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCall(tool_call) => return tool_call,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

async fn expect_tool_call_update_fields(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
    let event = events
        .next()
        .await
        .expect("no tool call update event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
            return update;
        }
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

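// Drains events until a `ToolCallAuthorization` arrives, asserting that the
// permission options are always offered in the same order.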
async fn next_tool_call_authorization(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
    loop {
        let event = events
            .next()
            .await
            .expect("no tool call authorization event received")
            .unwrap();
        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let permission_kinds = tool_call_authorization
                .options
                .iter()
                .map(|o| o.kind)
                .collect::<Vec<_>>();
            assert_eq!(
                permission_kinds,
                vec![
                    acp::PermissionOptionKind::AllowAlways,
                    acp::PermissionOptionKind::AllowOnce,
                    acp::PermissionOptionKind::RejectOnce,
                ]
            );
            return tool_call_authorization;
        }
    }
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}

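// Profiles defined in the settings file control which of the registered tools
// are offered to the model for a given thread.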
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Switch to the test-1 profile, and verify that it has the echo and delay tools.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to the test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

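// Starting a new send while a previous one is still streaming should cancel
// the first turn (Canceled) and let the second one run to completion.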
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove the last user
    // message along with everything that followed it.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}

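// Truncating at a user message removes that message and everything after it,
// and the thread remains usable for subsequent sends.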
#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );
    });
}

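// The summarization model should only be asked for a title on the first
// exchange; the title is taken from the first line of its streamed response.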
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}

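// End-to-end test of the ACP connection layer: listing and selecting models,
// creating a thread, prompting, cancellation, and session cleanup once the
// ACP thread is dropped.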
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}

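// Follows a single tool call through the full update sequence: pending with
// partial input, completed input, in-progress status, streamed content, and
// completion with raw output.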
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}

#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}

1525#[gpui::test]
1526async fn test_send_retry_on_error(cx: &mut TestAppContext) {
1527    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1528    let fake_model = model.as_fake();
1529
1530    let mut events = thread
1531        .update(cx, |thread, cx| {
1532            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1533            thread.send(UserMessageId::new(), ["Hello!"], cx)
1534        })
1535        .unwrap();
1536    cx.run_until_parked();
1537
1538    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
1539        provider: LanguageModelProviderName::new("Anthropic"),
1540        retry_after: Some(Duration::from_secs(3)),
1541    });
1542    fake_model.end_last_completion_stream();
1543
1544    cx.executor().advance_clock(Duration::from_secs(3));
1545    cx.run_until_parked();
1546
1547    fake_model.send_last_completion_stream_text_chunk("Hey!");
1548    fake_model.end_last_completion_stream();
1549
1550    let mut retry_events = Vec::new();
1551    while let Some(Ok(event)) = events.next().await {
1552        match event {
1553            ThreadEvent::Retry(retry_status) => {
1554                retry_events.push(retry_status);
1555            }
1556            ThreadEvent::Stop(..) => break,
1557            _ => {}
1558        }
1559    }
1560
1561    assert_eq!(retry_events.len(), 1);
1562    assert!(matches!(
1563        retry_events[0],
1564        acp_thread::RetryStatus { attempt: 1, .. }
1565    ));
1566    thread.read_with(cx, |thread, _cx| {
1567        assert_eq!(
1568            thread.to_markdown(),
1569            indoc! {"
1570                ## User
1571
1572                Hello!
1573
1574                ## Assistant
1575
1576                Hey!
1577            "}
1578        )
1579    });
1580}
1581
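    // Once MAX_RETRY_ATTEMPTS consecutive retryable errors have occurred, the thread
    // should stop retrying and surface the underlying error to the caller.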
1582#[gpui::test]
1583async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
1584    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1585    let fake_model = model.as_fake();
1586
1587    let mut events = thread
1588        .update(cx, |thread, cx| {
1589            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1590            thread.send(UserMessageId::new(), ["Hello!"], cx)
1591        })
1592        .unwrap();
1593    cx.run_until_parked();
1594
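        // Fail the initial attempt and every subsequent retry so the retry budget is exhausted.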
1595    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
1596        fake_model.send_last_completion_stream_error(
1597            LanguageModelCompletionError::ServerOverloaded {
1598                provider: LanguageModelProviderName::new("Anthropic"),
1599                retry_after: Some(Duration::from_secs(3)),
1600            },
1601        );
1602        fake_model.end_last_completion_stream();
1603        cx.executor().advance_clock(Duration::from_secs(3));
1604        cx.run_until_parked();
1605    }
1606
1607    let mut errors = Vec::new();
1608    let mut retry_events = Vec::new();
1609    while let Some(event) = events.next().await {
1610        match event {
1611            Ok(ThreadEvent::Retry(retry_status)) => {
1612                retry_events.push(retry_status);
1613            }
1614            Ok(ThreadEvent::Stop(..)) => break,
1615            Err(error) => errors.push(error),
1616            _ => {}
1617        }
1618    }
1619
1620    assert_eq!(
1621        retry_events.len(),
1622        crate::thread::MAX_RETRY_ATTEMPTS as usize
1623    );
1624    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
1625        assert_eq!(retry_events[i].attempt, i + 1);
1626    }
1627    assert_eq!(errors.len(), 1);
1628    let error = errors[0]
1629        .downcast_ref::<LanguageModelCompletionError>()
1630        .unwrap();
1631    assert!(matches!(
1632        error,
1633        LanguageModelCompletionError::ServerOverloaded { .. }
1634    ));
1635}
1636
1637/// Extracts just the stop events from a turn's event stream so tests can assert against them.
1638fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1639    result_events
1640        .into_iter()
1641        .filter_map(|event| match event.unwrap() {
1642            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1643            _ => None,
1644        })
1645        .collect()
1646}
1647
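    // Handles shared by the tests in this module; `fs` exposes the fake filesystem
    // created in `setup` so tests can modify it (e.g. the settings file) if needed.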
1648struct ThreadTest {
1649    model: Arc<dyn LanguageModel>,
1650    thread: Entity<Thread>,
1651    project_context: Entity<ProjectContext>,
1652    fs: Arc<FakeFs>,
1653}
1654
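    // Which language model a test runs against: the Sonnet 4 variants are resolved
    // through the LanguageModelRegistry and require provider authentication, while
    // `Fake` uses the in-process FakeLanguageModel.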
1655enum TestModel {
1656    Sonnet4,
1657    Sonnet4Thinking,
1658    Fake,
1659}
1660
1661impl TestModel {
1662    fn id(&self) -> LanguageModelId {
1663        match self {
1664            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1665            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1666            TestModel::Fake => unreachable!(),
1667        }
1668    }
1669}
1670
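    // Builds the test fixture: writes a settings file enabling the test tools under
    // "test-profile", initializes the global stores, resolves the requested model,
    // and constructs a `Thread` for a FakeFs project rooted at /test.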
1671async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1672    cx.executor().allow_parking();
1673
1674    let fs = FakeFs::new(cx.background_executor.clone());
1675    fs.create_dir(paths::settings_file().parent().unwrap())
1676        .await
1677        .unwrap();
1678    fs.insert_file(
1679        paths::settings_file(),
1680        json!({
1681            "agent": {
1682                "default_profile": "test-profile",
1683                "profiles": {
1684                    "test-profile": {
1685                        "name": "Test Profile",
1686                        "tools": {
1687                            EchoTool.name(): true,
1688                            DelayTool.name(): true,
1689                            WordListTool.name(): true,
1690                            ToolRequiringPermission.name(): true,
1691                            InfiniteTool.name(): true,
1692                        }
1693                    }
1694                }
1695            }
1696        })
1697        .to_string()
1698        .into_bytes(),
1699    )
1700    .await;
1701
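        // Initialize the globals the thread depends on: settings, project and agent
        // settings, an HTTP client, and the language model providers.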
1702    cx.update(|cx| {
1703        settings::init(cx);
1704        Project::init_settings(cx);
1705        agent_settings::init(cx);
1706        gpui_tokio::init(cx);
1707        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1708        cx.set_http_client(Arc::new(http_client));
1709
1710        client::init_settings(cx);
1711        let client = Client::production(cx);
1712        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1713        language_model::init(client.clone(), cx);
1714        language_models::init(user_store.clone(), client.clone(), cx);
1715
1716        watch_settings(fs.clone(), cx);
1717    });
1718
1719    let templates = Templates::new();
1720
1721    fs.insert_tree(path!("/test"), json!({})).await;
1722    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1723
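        // Resolve the requested model: the fake model is available immediately, while a
        // real model is looked up in the registry and its provider authenticated first.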
1724    let model = cx
1725        .update(|cx| {
1726            if let TestModel::Fake = model {
1727                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1728            } else {
1729                let model_id = model.id();
1730                let models = LanguageModelRegistry::read_global(cx);
1731                let model = models
1732                    .available_models(cx)
1733                    .find(|model| model.id() == model_id)
1734                    .unwrap();
1735
1736                let provider = models.provider(&model.provider_id()).unwrap();
1737                let authenticated = provider.authenticate(cx);
1738
1739                cx.spawn(async move |_cx| {
1740                    authenticated.await.unwrap();
1741                    model
1742                })
1743            }
1744        })
1745        .await;
1746
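        // Assemble the thread's collaborators and construct it with the resolved model.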
1747    let project_context = cx.new(|_cx| ProjectContext::default());
1748    let context_server_registry =
1749        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1750    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1751    let thread = cx.new(|cx| {
1752        Thread::new(
1753            project,
1754            project_context.clone(),
1755            context_server_registry,
1756            action_log,
1757            templates,
1758            Some(model.clone()),
1759            None,
1760            cx,
1761        )
1762    });
1763    ThreadTest {
1764        model,
1765        thread,
1766        project_context,
1767        fs,
1768    }
1769}
1770
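    // Enable env_logger output for test runs when RUST_LOG is set.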
1771#[cfg(test)]
1772#[ctor::ctor]
1773fn init_logger() {
1774    if std::env::var("RUST_LOG").is_ok() {
1775        env_logger::init();
1776    }
1777}
1778
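    // Re-applies the settings file to the global SettingsStore whenever the fake
    // filesystem reports that it changed, so settings edits made by tests take effect.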
1779fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1780    let fs = fs.clone();
1781    cx.spawn({
1782        async move |cx| {
1783            let mut new_settings_content_rx = settings::watch_config_file(
1784                cx.background_executor(),
1785                fs,
1786                paths::settings_file().clone(),
1787            );
1788
1789            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1790                cx.update(|cx| {
1791                    SettingsStore::update_global(cx, |settings, cx| {
1792                        settings.set_user_settings(&new_settings_content, cx)
1793                    })
1794                })
1795                .ok();
1796            }
1797        }
1798    })
1799    .detach();
1800}