use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
use gpui::{
    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
    fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;

mod test_tools;
use test_tools::*;

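// Tests marked `#[ignore = "can't run on CI yet"]` exercise real Anthropic models
// (`TestModel::Sonnet4`/`Sonnet4Thinking`); the rest drive a deterministic `FakeLanguageModel`.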
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap()
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

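    // `next_tool_call_authorization` asserts the option order is
    // [AllowAlways, AllowOnce, RejectOnce], so index 1 approves once and index 2 rejects.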
    // Approve the first
    tool_call_auth_1
        .response
        .send(tool_call_auth_1.options[1].id.clone())
        .unwrap();
    cx.run_until_parked();

    // Reject the second
    tool_call_auth_2
        .response
        .send(tool_call_auth_1.options[2].id.clone())
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: None
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools.
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(tool_call_auth_3.options[0].id.clone())
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}

#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}

#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}

#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}

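/// Waits for the next thread event and asserts that it is a `ThreadEvent::ToolCall`,
/// returning the tool call.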
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
    let event = events
        .next()
        .await
        .expect("no tool call event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCall(tool_call) => tool_call,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

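/// Waits for the next thread event and asserts that it is a tool call field update,
/// returning the update.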
async fn expect_tool_call_update_fields(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
    let event = events
        .next()
        .await
        .expect("no tool call update event received")
        .unwrap();
    match event {
        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

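/// Skips events until the next `ThreadEvent::ToolCallAuthorization`, asserting that its
/// permission options are [AllowAlways, AllowOnce, RejectOnce] before returning it.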
async fn next_tool_call_authorization(
    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
    loop {
        let event = events
            .next()
            .await
            .expect("no tool call authorization event received")
            .unwrap();
        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let permission_kinds = tool_call_authorization
                .options
                .iter()
                .map(|o| o.kind)
                .collect::<Vec<_>>();
            assert_eq!(
                permission_kinds,
                vec![
                    acp::PermissionOptionKind::AllowAlways,
                    acp::PermissionOptionKind::AllowOnce,
                    acp::PermissionOptionKind::RejectOnce,
                ]
            );
            return tool_call_authorization;
        }
    }
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}

#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}

#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );
    });
}

#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}

#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}

#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}

#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}

#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}

#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}

/// Extracts just the stop events from a collected event stream, for asserting against in tests
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
    result_events
        .into_iter()
        .filter_map(|event| match event.unwrap() {
            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
            _ => None,
        })
        .collect()
}

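/// Entities shared by most tests: the selected model, the thread under test, its project
/// context, and the fake filesystem used to drive settings changes.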
struct ThreadTest {
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    fs: Arc<FakeFs>,
}

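/// Model to run a test against: two real Anthropic configurations (used by the `#[ignore]`d
/// tests) or the deterministic fake model.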
enum TestModel {
    Sonnet4,
    Sonnet4Thinking,
    Fake,
}

impl TestModel {
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}

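/// Builds a `ThreadTest`: installs a test settings file with a "test-profile" that enables all
/// the test tools, initializes the client and language model registry, and creates a `Thread`
/// backed by the requested model.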
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}

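/// Enables `env_logger` output for these tests when `RUST_LOG` is set.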
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}

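/// Spawns a task that watches the settings file on the (fake) filesystem and applies any new
/// contents to the global `SettingsStore`, so tests can change settings by rewriting the file.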
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}