mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol::{self as acp};
   5use agent_settings::AgentProfileId;
   6use anyhow::Result;
   7use client::{Client, UserStore};
   8use fs::{FakeFs, Fs};
   9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
  10use gpui::{
  11    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  12};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  16    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
  17    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
  18    fake_provider::FakeLanguageModel,
  19};
  20use pretty_assertions::assert_eq;
  21use project::Project;
  22use prompt_store::ProjectContext;
  23use reqwest_client::ReqwestClient;
  24use schemars::JsonSchema;
  25use serde::{Deserialize, Serialize};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  29use util::path;
  30
  31mod test_tools;
  32use test_tools::*;
  33
  34#[gpui::test]
  35#[ignore = "can't run on CI yet"]
  36async fn test_echo(cx: &mut TestAppContext) {
  37    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  38
  39    let events = thread
  40        .update(cx, |thread, cx| {
  41            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  42        })
  43        .unwrap()
  44        .collect()
  45        .await;
  46    thread.update(cx, |thread, _cx| {
  47        assert_eq!(
  48            thread.last_message().unwrap().to_markdown(),
  49            indoc! {"
  50                ## Assistant
  51
  52                Hello
  53            "}
  54        )
  55    });
  56    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  57}
  58
  59#[gpui::test]
  60#[ignore = "can't run on CI yet"]
  61async fn test_thinking(cx: &mut TestAppContext) {
  62    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  63
  64    let events = thread
  65        .update(cx, |thread, cx| {
  66            thread.send(
  67                UserMessageId::new(),
  68                [indoc! {"
  69                    Testing:
  70
  71                    Generate a thinking step where you just think the word 'Think',
  72                    and have your final answer be 'Hello'
  73                "}],
  74                cx,
  75            )
  76        })
  77        .unwrap()
  78        .collect()
  79        .await;
  80    thread.update(cx, |thread, _cx| {
  81        assert_eq!(
  82            thread.last_message().unwrap().to_markdown(),
  83            indoc! {"
  84                ## Assistant
  85
  86                <think>Think</think>
  87                Hello
  88            "}
  89        )
  90    });
  91    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  92}
  93
  94#[gpui::test]
  95async fn test_system_prompt(cx: &mut TestAppContext) {
  96    let ThreadTest {
  97        model,
  98        thread,
  99        project_context,
 100        ..
 101    } = setup(cx, TestModel::Fake).await;
 102    let fake_model = model.as_fake();
 103
 104    project_context.update(cx, |project_context, _cx| {
 105        project_context.shell = "test-shell".into()
 106    });
 107    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 108    thread
 109        .update(cx, |thread, cx| {
 110            thread.send(UserMessageId::new(), ["abc"], cx)
 111        })
 112        .unwrap();
 113    cx.run_until_parked();
 114    let mut pending_completions = fake_model.pending_completions();
 115    assert_eq!(
 116        pending_completions.len(),
 117        1,
 118        "unexpected pending completions: {:?}",
 119        pending_completions
 120    );
 121
 122    let pending_completion = pending_completions.pop().unwrap();
 123    assert_eq!(pending_completion.messages[0].role, Role::System);
 124
 125    let system_message = &pending_completion.messages[0];
 126    let system_prompt = system_message.content[0].to_str().unwrap();
 127    assert!(
 128        system_prompt.contains("test-shell"),
 129        "unexpected system message: {:?}",
 130        system_message
 131    );
 132    assert!(
 133        system_prompt.contains("## Fixing Diagnostics"),
 134        "unexpected system message: {:?}",
 135        system_message
 136    );
 137}
 138
 139#[gpui::test]
 140async fn test_prompt_caching(cx: &mut TestAppContext) {
 141    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 142    let fake_model = model.as_fake();
 143
 144    // Send initial user message and verify it's cached
 145    thread
 146        .update(cx, |thread, cx| {
 147            thread.send(UserMessageId::new(), ["Message 1"], cx)
 148        })
 149        .unwrap();
 150    cx.run_until_parked();
 151
 152    let completion = fake_model.pending_completions().pop().unwrap();
 153    assert_eq!(
 154        completion.messages[1..],
 155        vec![LanguageModelRequestMessage {
 156            role: Role::User,
 157            content: vec!["Message 1".into()],
 158            cache: true
 159        }]
 160    );
 161    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 162        "Response to Message 1".into(),
 163    ));
 164    fake_model.end_last_completion_stream();
 165    cx.run_until_parked();
 166
 167    // Send another user message and verify only the latest is cached
 168    thread
 169        .update(cx, |thread, cx| {
 170            thread.send(UserMessageId::new(), ["Message 2"], cx)
 171        })
 172        .unwrap();
 173    cx.run_until_parked();
 174
 175    let completion = fake_model.pending_completions().pop().unwrap();
 176    assert_eq!(
 177        completion.messages[1..],
 178        vec![
 179            LanguageModelRequestMessage {
 180                role: Role::User,
 181                content: vec!["Message 1".into()],
 182                cache: false
 183            },
 184            LanguageModelRequestMessage {
 185                role: Role::Assistant,
 186                content: vec!["Response to Message 1".into()],
 187                cache: false
 188            },
 189            LanguageModelRequestMessage {
 190                role: Role::User,
 191                content: vec!["Message 2".into()],
 192                cache: true
 193            }
 194        ]
 195    );
 196    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 197        "Response to Message 2".into(),
 198    ));
 199    fake_model.end_last_completion_stream();
 200    cx.run_until_parked();
 201
 202    // Simulate a tool call and verify that the latest tool result is cached
 203    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 204    thread
 205        .update(cx, |thread, cx| {
 206            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
 207        })
 208        .unwrap();
 209    cx.run_until_parked();
 210
 211    let tool_use = LanguageModelToolUse {
 212        id: "tool_1".into(),
 213        name: EchoTool.name().into(),
 214        raw_input: json!({"text": "test"}).to_string(),
 215        input: json!({"text": "test"}),
 216        is_input_complete: true,
 217    };
 218    fake_model
 219        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 220    fake_model.end_last_completion_stream();
 221    cx.run_until_parked();
 222
 223    let completion = fake_model.pending_completions().pop().unwrap();
 224    let tool_result = LanguageModelToolResult {
 225        tool_use_id: "tool_1".into(),
 226        tool_name: EchoTool.name().into(),
 227        is_error: false,
 228        content: "test".into(),
 229        output: Some("test".into()),
 230    };
 231    assert_eq!(
 232        completion.messages[1..],
 233        vec![
 234            LanguageModelRequestMessage {
 235                role: Role::User,
 236                content: vec!["Message 1".into()],
 237                cache: false
 238            },
 239            LanguageModelRequestMessage {
 240                role: Role::Assistant,
 241                content: vec!["Response to Message 1".into()],
 242                cache: false
 243            },
 244            LanguageModelRequestMessage {
 245                role: Role::User,
 246                content: vec!["Message 2".into()],
 247                cache: false
 248            },
 249            LanguageModelRequestMessage {
 250                role: Role::Assistant,
 251                content: vec!["Response to Message 2".into()],
 252                cache: false
 253            },
 254            LanguageModelRequestMessage {
 255                role: Role::User,
 256                content: vec!["Use the echo tool".into()],
 257                cache: false
 258            },
 259            LanguageModelRequestMessage {
 260                role: Role::Assistant,
 261                content: vec![MessageContent::ToolUse(tool_use)],
 262                cache: false
 263            },
 264            LanguageModelRequestMessage {
 265                role: Role::User,
 266                content: vec![MessageContent::ToolResult(tool_result)],
 267                cache: true
 268            }
 269        ]
 270    );
 271}
 272
 273#[gpui::test]
 274#[ignore = "can't run on CI yet"]
 275async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 276    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 277
 278    // Test a tool call that's likely to complete *before* streaming stops.
 279    let events = thread
 280        .update(cx, |thread, cx| {
 281            thread.add_tool(EchoTool);
 282            thread.send(
 283                UserMessageId::new(),
 284                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 285                cx,
 286            )
 287        })
 288        .unwrap()
 289        .collect()
 290        .await;
 291    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 292
 293    // Test a tool calls that's likely to complete *after* streaming stops.
 294    let events = thread
 295        .update(cx, |thread, cx| {
 296            thread.remove_tool(&AgentTool::name(&EchoTool));
 297            thread.add_tool(DelayTool);
 298            thread.send(
 299                UserMessageId::new(),
 300                [
 301                    "Now call the delay tool with 200ms.",
 302                    "When the timer goes off, then you echo the output of the tool.",
 303                ],
 304                cx,
 305            )
 306        })
 307        .unwrap()
 308        .collect()
 309        .await;
 310    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 311    thread.update(cx, |thread, _cx| {
 312        assert!(
 313            thread
 314                .last_message()
 315                .unwrap()
 316                .as_agent_message()
 317                .unwrap()
 318                .content
 319                .iter()
 320                .any(|content| {
 321                    if let AgentMessageContent::Text(text) = content {
 322                        text.contains("Ding")
 323                    } else {
 324                        false
 325                    }
 326                }),
 327            "{}",
 328            thread.to_markdown()
 329        );
 330    });
 331}
 332
 333#[gpui::test]
 334#[ignore = "can't run on CI yet"]
 335async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 336    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 337
 338    // Test a tool call that's likely to complete *before* streaming stops.
 339    let mut events = thread
 340        .update(cx, |thread, cx| {
 341            thread.add_tool(WordListTool);
 342            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 343        })
 344        .unwrap();
 345
 346    let mut saw_partial_tool_use = false;
 347    while let Some(event) = events.next().await {
 348        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
 349            thread.update(cx, |thread, _cx| {
 350                // Look for a tool use in the thread's last message
 351                let message = thread.last_message().unwrap();
 352                let agent_message = message.as_agent_message().unwrap();
 353                let last_content = agent_message.content.last().unwrap();
 354                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 355                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 356                    if tool_call.status == acp::ToolCallStatus::Pending {
 357                        if !last_tool_use.is_input_complete
 358                            && last_tool_use.input.get("g").is_none()
 359                        {
 360                            saw_partial_tool_use = true;
 361                        }
 362                    } else {
 363                        last_tool_use
 364                            .input
 365                            .get("a")
 366                            .expect("'a' has streamed because input is now complete");
 367                        last_tool_use
 368                            .input
 369                            .get("g")
 370                            .expect("'g' has streamed because input is now complete");
 371                    }
 372                } else {
 373                    panic!("last content should be a tool use");
 374                }
 375            });
 376        }
 377    }
 378
 379    assert!(
 380        saw_partial_tool_use,
 381        "should see at least one partially streamed tool use in the history"
 382    );
 383}
 384
 385#[gpui::test]
 386async fn test_tool_authorization(cx: &mut TestAppContext) {
 387    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 388    let fake_model = model.as_fake();
 389
 390    let mut events = thread
 391        .update(cx, |thread, cx| {
 392            thread.add_tool(ToolRequiringPermission);
 393            thread.send(UserMessageId::new(), ["abc"], cx)
 394        })
 395        .unwrap();
 396    cx.run_until_parked();
 397    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 398        LanguageModelToolUse {
 399            id: "tool_id_1".into(),
 400            name: ToolRequiringPermission.name().into(),
 401            raw_input: "{}".into(),
 402            input: json!({}),
 403            is_input_complete: true,
 404        },
 405    ));
 406    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 407        LanguageModelToolUse {
 408            id: "tool_id_2".into(),
 409            name: ToolRequiringPermission.name().into(),
 410            raw_input: "{}".into(),
 411            input: json!({}),
 412            is_input_complete: true,
 413        },
 414    ));
 415    fake_model.end_last_completion_stream();
 416    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 417    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 418
 419    // Approve the first
 420    tool_call_auth_1
 421        .response
 422        .send(tool_call_auth_1.options[1].id.clone())
 423        .unwrap();
 424    cx.run_until_parked();
 425
 426    // Reject the second
 427    tool_call_auth_2
 428        .response
 429        .send(tool_call_auth_1.options[2].id.clone())
 430        .unwrap();
 431    cx.run_until_parked();
 432
 433    let completion = fake_model.pending_completions().pop().unwrap();
 434    let message = completion.messages.last().unwrap();
 435    assert_eq!(
 436        message.content,
 437        vec![
 438            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 439                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 440                tool_name: ToolRequiringPermission.name().into(),
 441                is_error: false,
 442                content: "Allowed".into(),
 443                output: Some("Allowed".into())
 444            }),
 445            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 446                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 447                tool_name: ToolRequiringPermission.name().into(),
 448                is_error: true,
 449                content: "Permission to run tool denied by user".into(),
 450                output: None
 451            })
 452        ]
 453    );
 454
 455    // Simulate yet another tool call.
 456    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 457        LanguageModelToolUse {
 458            id: "tool_id_3".into(),
 459            name: ToolRequiringPermission.name().into(),
 460            raw_input: "{}".into(),
 461            input: json!({}),
 462            is_input_complete: true,
 463        },
 464    ));
 465    fake_model.end_last_completion_stream();
 466
 467    // Respond by always allowing tools.
 468    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 469    tool_call_auth_3
 470        .response
 471        .send(tool_call_auth_3.options[0].id.clone())
 472        .unwrap();
 473    cx.run_until_parked();
 474    let completion = fake_model.pending_completions().pop().unwrap();
 475    let message = completion.messages.last().unwrap();
 476    assert_eq!(
 477        message.content,
 478        vec![language_model::MessageContent::ToolResult(
 479            LanguageModelToolResult {
 480                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 481                tool_name: ToolRequiringPermission.name().into(),
 482                is_error: false,
 483                content: "Allowed".into(),
 484                output: Some("Allowed".into())
 485            }
 486        )]
 487    );
 488
 489    // Simulate a final tool call, ensuring we don't trigger authorization.
 490    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 491        LanguageModelToolUse {
 492            id: "tool_id_4".into(),
 493            name: ToolRequiringPermission.name().into(),
 494            raw_input: "{}".into(),
 495            input: json!({}),
 496            is_input_complete: true,
 497        },
 498    ));
 499    fake_model.end_last_completion_stream();
 500    cx.run_until_parked();
 501    let completion = fake_model.pending_completions().pop().unwrap();
 502    let message = completion.messages.last().unwrap();
 503    assert_eq!(
 504        message.content,
 505        vec![language_model::MessageContent::ToolResult(
 506            LanguageModelToolResult {
 507                tool_use_id: "tool_id_4".into(),
 508                tool_name: ToolRequiringPermission.name().into(),
 509                is_error: false,
 510                content: "Allowed".into(),
 511                output: Some("Allowed".into())
 512            }
 513        )]
 514    );
 515}
 516
 517#[gpui::test]
 518async fn test_tool_hallucination(cx: &mut TestAppContext) {
 519    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 520    let fake_model = model.as_fake();
 521
 522    let mut events = thread
 523        .update(cx, |thread, cx| {
 524            thread.send(UserMessageId::new(), ["abc"], cx)
 525        })
 526        .unwrap();
 527    cx.run_until_parked();
 528    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 529        LanguageModelToolUse {
 530            id: "tool_id_1".into(),
 531            name: "nonexistent_tool".into(),
 532            raw_input: "{}".into(),
 533            input: json!({}),
 534            is_input_complete: true,
 535        },
 536    ));
 537    fake_model.end_last_completion_stream();
 538
 539    let tool_call = expect_tool_call(&mut events).await;
 540    assert_eq!(tool_call.title, "nonexistent_tool");
 541    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 542    let update = expect_tool_call_update_fields(&mut events).await;
 543    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 544}
 545
 546#[gpui::test]
 547async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
 548    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 549    let fake_model = model.as_fake();
 550
 551    let events = thread
 552        .update(cx, |thread, cx| {
 553            thread.add_tool(EchoTool);
 554            thread.send(UserMessageId::new(), ["abc"], cx)
 555        })
 556        .unwrap();
 557    cx.run_until_parked();
 558    let tool_use = LanguageModelToolUse {
 559        id: "tool_id_1".into(),
 560        name: EchoTool.name().into(),
 561        raw_input: "{}".into(),
 562        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 563        is_input_complete: true,
 564    };
 565    fake_model
 566        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 567    fake_model.end_last_completion_stream();
 568
 569    cx.run_until_parked();
 570    let completion = fake_model.pending_completions().pop().unwrap();
 571    let tool_result = LanguageModelToolResult {
 572        tool_use_id: "tool_id_1".into(),
 573        tool_name: EchoTool.name().into(),
 574        is_error: false,
 575        content: "def".into(),
 576        output: Some("def".into()),
 577    };
 578    assert_eq!(
 579        completion.messages[1..],
 580        vec![
 581            LanguageModelRequestMessage {
 582                role: Role::User,
 583                content: vec!["abc".into()],
 584                cache: false
 585            },
 586            LanguageModelRequestMessage {
 587                role: Role::Assistant,
 588                content: vec![MessageContent::ToolUse(tool_use.clone())],
 589                cache: false
 590            },
 591            LanguageModelRequestMessage {
 592                role: Role::User,
 593                content: vec![MessageContent::ToolResult(tool_result.clone())],
 594                cache: true
 595            },
 596        ]
 597    );
 598
 599    // Simulate reaching tool use limit.
 600    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 601        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 602    ));
 603    fake_model.end_last_completion_stream();
 604    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 605    assert!(
 606        last_event
 607            .unwrap_err()
 608            .is::<language_model::ToolUseLimitReachedError>()
 609    );
 610
 611    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
 612    cx.run_until_parked();
 613    let completion = fake_model.pending_completions().pop().unwrap();
 614    assert_eq!(
 615        completion.messages[1..],
 616        vec![
 617            LanguageModelRequestMessage {
 618                role: Role::User,
 619                content: vec!["abc".into()],
 620                cache: false
 621            },
 622            LanguageModelRequestMessage {
 623                role: Role::Assistant,
 624                content: vec![MessageContent::ToolUse(tool_use)],
 625                cache: false
 626            },
 627            LanguageModelRequestMessage {
 628                role: Role::User,
 629                content: vec![MessageContent::ToolResult(tool_result)],
 630                cache: false
 631            },
 632            LanguageModelRequestMessage {
 633                role: Role::User,
 634                content: vec!["Continue where you left off".into()],
 635                cache: true
 636            }
 637        ]
 638    );
 639
 640    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
 641    fake_model.end_last_completion_stream();
 642    events.collect::<Vec<_>>().await;
 643    thread.read_with(cx, |thread, _cx| {
 644        assert_eq!(
 645            thread.last_message().unwrap().to_markdown(),
 646            indoc! {"
 647                ## Assistant
 648
 649                Done
 650            "}
 651        )
 652    });
 653
 654    // Ensure we error if calling resume when tool use limit was *not* reached.
 655    let error = thread
 656        .update(cx, |thread, cx| thread.resume(cx))
 657        .unwrap_err();
 658    assert_eq!(
 659        error.to_string(),
 660        "can only resume after tool use limit is reached"
 661    )
 662}
 663
 664#[gpui::test]
 665async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
 666    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 667    let fake_model = model.as_fake();
 668
 669    let events = thread
 670        .update(cx, |thread, cx| {
 671            thread.add_tool(EchoTool);
 672            thread.send(UserMessageId::new(), ["abc"], cx)
 673        })
 674        .unwrap();
 675    cx.run_until_parked();
 676
 677    let tool_use = LanguageModelToolUse {
 678        id: "tool_id_1".into(),
 679        name: EchoTool.name().into(),
 680        raw_input: "{}".into(),
 681        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 682        is_input_complete: true,
 683    };
 684    let tool_result = LanguageModelToolResult {
 685        tool_use_id: "tool_id_1".into(),
 686        tool_name: EchoTool.name().into(),
 687        is_error: false,
 688        content: "def".into(),
 689        output: Some("def".into()),
 690    };
 691    fake_model
 692        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 693    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 694        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 695    ));
 696    fake_model.end_last_completion_stream();
 697    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 698    assert!(
 699        last_event
 700            .unwrap_err()
 701            .is::<language_model::ToolUseLimitReachedError>()
 702    );
 703
 704    thread
 705        .update(cx, |thread, cx| {
 706            thread.send(UserMessageId::new(), vec!["ghi"], cx)
 707        })
 708        .unwrap();
 709    cx.run_until_parked();
 710    let completion = fake_model.pending_completions().pop().unwrap();
 711    assert_eq!(
 712        completion.messages[1..],
 713        vec![
 714            LanguageModelRequestMessage {
 715                role: Role::User,
 716                content: vec!["abc".into()],
 717                cache: false
 718            },
 719            LanguageModelRequestMessage {
 720                role: Role::Assistant,
 721                content: vec![MessageContent::ToolUse(tool_use)],
 722                cache: false
 723            },
 724            LanguageModelRequestMessage {
 725                role: Role::User,
 726                content: vec![MessageContent::ToolResult(tool_result)],
 727                cache: false
 728            },
 729            LanguageModelRequestMessage {
 730                role: Role::User,
 731                content: vec!["ghi".into()],
 732                cache: true
 733            }
 734        ]
 735    );
 736}
 737
 738async fn expect_tool_call(
 739    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 740) -> acp::ToolCall {
 741    let event = events
 742        .next()
 743        .await
 744        .expect("no tool call authorization event received")
 745        .unwrap();
 746    match event {
 747        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
 748        event => {
 749            panic!("Unexpected event {event:?}");
 750        }
 751    }
 752}
 753
 754async fn expect_tool_call_update_fields(
 755    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 756) -> acp::ToolCallUpdate {
 757    let event = events
 758        .next()
 759        .await
 760        .expect("no tool call authorization event received")
 761        .unwrap();
 762    match event {
 763        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
 764            return update;
 765        }
 766        event => {
 767            panic!("Unexpected event {event:?}");
 768        }
 769    }
 770}
 771
 772async fn next_tool_call_authorization(
 773    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 774) -> ToolCallAuthorization {
 775    loop {
 776        let event = events
 777            .next()
 778            .await
 779            .expect("no tool call authorization event received")
 780            .unwrap();
 781        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
 782            let permission_kinds = tool_call_authorization
 783                .options
 784                .iter()
 785                .map(|o| o.kind)
 786                .collect::<Vec<_>>();
 787            assert_eq!(
 788                permission_kinds,
 789                vec![
 790                    acp::PermissionOptionKind::AllowAlways,
 791                    acp::PermissionOptionKind::AllowOnce,
 792                    acp::PermissionOptionKind::RejectOnce,
 793                ]
 794            );
 795            return tool_call_authorization;
 796        }
 797    }
 798}
 799
 800#[gpui::test]
 801#[ignore = "can't run on CI yet"]
 802async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 803    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 804
 805    // Test concurrent tool calls with different delay times
 806    let events = thread
 807        .update(cx, |thread, cx| {
 808            thread.add_tool(DelayTool);
 809            thread.send(
 810                UserMessageId::new(),
 811                [
 812                    "Call the delay tool twice in the same message.",
 813                    "Once with 100ms. Once with 300ms.",
 814                    "When both timers are complete, describe the outputs.",
 815                ],
 816                cx,
 817            )
 818        })
 819        .unwrap()
 820        .collect()
 821        .await;
 822
 823    let stop_reasons = stop_events(events);
 824    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 825
 826    thread.update(cx, |thread, _cx| {
 827        let last_message = thread.last_message().unwrap();
 828        let agent_message = last_message.as_agent_message().unwrap();
 829        let text = agent_message
 830            .content
 831            .iter()
 832            .filter_map(|content| {
 833                if let AgentMessageContent::Text(text) = content {
 834                    Some(text.as_str())
 835                } else {
 836                    None
 837                }
 838            })
 839            .collect::<String>();
 840
 841        assert!(text.contains("Ding"));
 842    });
 843}
 844
 845#[gpui::test]
 846async fn test_profiles(cx: &mut TestAppContext) {
 847    let ThreadTest {
 848        model, thread, fs, ..
 849    } = setup(cx, TestModel::Fake).await;
 850    let fake_model = model.as_fake();
 851
 852    thread.update(cx, |thread, _cx| {
 853        thread.add_tool(DelayTool);
 854        thread.add_tool(EchoTool);
 855        thread.add_tool(InfiniteTool);
 856    });
 857
 858    // Override profiles and wait for settings to be loaded.
 859    fs.insert_file(
 860        paths::settings_file(),
 861        json!({
 862            "agent": {
 863                "profiles": {
 864                    "test-1": {
 865                        "name": "Test Profile 1",
 866                        "tools": {
 867                            EchoTool.name(): true,
 868                            DelayTool.name(): true,
 869                        }
 870                    },
 871                    "test-2": {
 872                        "name": "Test Profile 2",
 873                        "tools": {
 874                            InfiniteTool.name(): true,
 875                        }
 876                    }
 877                }
 878            }
 879        })
 880        .to_string()
 881        .into_bytes(),
 882    )
 883    .await;
 884    cx.run_until_parked();
 885
 886    // Test that test-1 profile (default) has echo and delay tools
 887    thread
 888        .update(cx, |thread, cx| {
 889            thread.set_profile(AgentProfileId("test-1".into()));
 890            thread.send(UserMessageId::new(), ["test"], cx)
 891        })
 892        .unwrap();
 893    cx.run_until_parked();
 894
 895    let mut pending_completions = fake_model.pending_completions();
 896    assert_eq!(pending_completions.len(), 1);
 897    let completion = pending_completions.pop().unwrap();
 898    let tool_names: Vec<String> = completion
 899        .tools
 900        .iter()
 901        .map(|tool| tool.name.clone())
 902        .collect();
 903    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 904    fake_model.end_last_completion_stream();
 905
 906    // Switch to test-2 profile, and verify that it has only the infinite tool.
 907    thread
 908        .update(cx, |thread, cx| {
 909            thread.set_profile(AgentProfileId("test-2".into()));
 910            thread.send(UserMessageId::new(), ["test2"], cx)
 911        })
 912        .unwrap();
 913    cx.run_until_parked();
 914    let mut pending_completions = fake_model.pending_completions();
 915    assert_eq!(pending_completions.len(), 1);
 916    let completion = pending_completions.pop().unwrap();
 917    let tool_names: Vec<String> = completion
 918        .tools
 919        .iter()
 920        .map(|tool| tool.name.clone())
 921        .collect();
 922    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 923}
 924
 925#[gpui::test]
 926#[ignore = "can't run on CI yet"]
 927async fn test_cancellation(cx: &mut TestAppContext) {
 928    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 929
 930    let mut events = thread
 931        .update(cx, |thread, cx| {
 932            thread.add_tool(InfiniteTool);
 933            thread.add_tool(EchoTool);
 934            thread.send(
 935                UserMessageId::new(),
 936                ["Call the echo tool, then call the infinite tool, then explain their output"],
 937                cx,
 938            )
 939        })
 940        .unwrap();
 941
 942    // Wait until both tools are called.
 943    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 944    let mut echo_id = None;
 945    let mut echo_completed = false;
 946    while let Some(event) = events.next().await {
 947        match event.unwrap() {
 948            AgentResponseEvent::ToolCall(tool_call) => {
 949                assert_eq!(tool_call.title, expected_tools.remove(0));
 950                if tool_call.title == "Echo" {
 951                    echo_id = Some(tool_call.id);
 952                }
 953            }
 954            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 955                acp::ToolCallUpdate {
 956                    id,
 957                    fields:
 958                        acp::ToolCallUpdateFields {
 959                            status: Some(acp::ToolCallStatus::Completed),
 960                            ..
 961                        },
 962                },
 963            )) if Some(&id) == echo_id.as_ref() => {
 964                echo_completed = true;
 965            }
 966            _ => {}
 967        }
 968
 969        if expected_tools.is_empty() && echo_completed {
 970            break;
 971        }
 972    }
 973
 974    // Cancel the current send and ensure that the event stream is closed, even
 975    // if one of the tools is still running.
 976    thread.update(cx, |thread, _cx| thread.cancel());
 977    let events = events.collect::<Vec<_>>().await;
 978    let last_event = events.last();
 979    assert!(
 980        matches!(
 981            last_event,
 982            Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
 983        ),
 984        "unexpected event {last_event:?}"
 985    );
 986
 987    // Ensure we can still send a new message after cancellation.
 988    let events = thread
 989        .update(cx, |thread, cx| {
 990            thread.send(
 991                UserMessageId::new(),
 992                ["Testing: reply with 'Hello' then stop."],
 993                cx,
 994            )
 995        })
 996        .unwrap()
 997        .collect::<Vec<_>>()
 998        .await;
 999    thread.update(cx, |thread, _cx| {
1000        let message = thread.last_message().unwrap();
1001        let agent_message = message.as_agent_message().unwrap();
1002        assert_eq!(
1003            agent_message.content,
1004            vec![AgentMessageContent::Text("Hello".to_string())]
1005        );
1006    });
1007    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1008}
1009
1010#[gpui::test]
1011async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1012    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1013    let fake_model = model.as_fake();
1014
1015    let events_1 = thread
1016        .update(cx, |thread, cx| {
1017            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1018        })
1019        .unwrap();
1020    cx.run_until_parked();
1021    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1022    cx.run_until_parked();
1023
1024    let events_2 = thread
1025        .update(cx, |thread, cx| {
1026            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1027        })
1028        .unwrap();
1029    cx.run_until_parked();
1030    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1031    fake_model
1032        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1033    fake_model.end_last_completion_stream();
1034
1035    let events_1 = events_1.collect::<Vec<_>>().await;
1036    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
1037    let events_2 = events_2.collect::<Vec<_>>().await;
1038    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1039}
1040
1041#[gpui::test]
1042async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1043    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1044    let fake_model = model.as_fake();
1045
1046    let events_1 = thread
1047        .update(cx, |thread, cx| {
1048            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1049        })
1050        .unwrap();
1051    cx.run_until_parked();
1052    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1053    fake_model
1054        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1055    fake_model.end_last_completion_stream();
1056    let events_1 = events_1.collect::<Vec<_>>().await;
1057
1058    let events_2 = thread
1059        .update(cx, |thread, cx| {
1060            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1061        })
1062        .unwrap();
1063    cx.run_until_parked();
1064    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1065    fake_model
1066        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1067    fake_model.end_last_completion_stream();
1068    let events_2 = events_2.collect::<Vec<_>>().await;
1069
1070    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1071    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1072}
1073
1074#[gpui::test]
1075async fn test_refusal(cx: &mut TestAppContext) {
1076    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1077    let fake_model = model.as_fake();
1078
1079    let events = thread
1080        .update(cx, |thread, cx| {
1081            thread.send(UserMessageId::new(), ["Hello"], cx)
1082        })
1083        .unwrap();
1084    cx.run_until_parked();
1085    thread.read_with(cx, |thread, _| {
1086        assert_eq!(
1087            thread.to_markdown(),
1088            indoc! {"
1089                ## User
1090
1091                Hello
1092            "}
1093        );
1094    });
1095
1096    fake_model.send_last_completion_stream_text_chunk("Hey!");
1097    cx.run_until_parked();
1098    thread.read_with(cx, |thread, _| {
1099        assert_eq!(
1100            thread.to_markdown(),
1101            indoc! {"
1102                ## User
1103
1104                Hello
1105
1106                ## Assistant
1107
1108                Hey!
1109            "}
1110        );
1111    });
1112
1113    // If the model refuses to continue, the thread should remove all the messages after the last user message.
1114    fake_model
1115        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1116    let events = events.collect::<Vec<_>>().await;
1117    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1118    thread.read_with(cx, |thread, _| {
1119        assert_eq!(thread.to_markdown(), "");
1120    });
1121}
1122
1123#[gpui::test]
1124async fn test_truncate(cx: &mut TestAppContext) {
1125    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1126    let fake_model = model.as_fake();
1127
1128    let message_id = UserMessageId::new();
1129    thread
1130        .update(cx, |thread, cx| {
1131            thread.send(message_id.clone(), ["Hello"], cx)
1132        })
1133        .unwrap();
1134    cx.run_until_parked();
1135    thread.read_with(cx, |thread, _| {
1136        assert_eq!(
1137            thread.to_markdown(),
1138            indoc! {"
1139                ## User
1140
1141                Hello
1142            "}
1143        );
1144    });
1145
1146    fake_model.send_last_completion_stream_text_chunk("Hey!");
1147    cx.run_until_parked();
1148    thread.read_with(cx, |thread, _| {
1149        assert_eq!(
1150            thread.to_markdown(),
1151            indoc! {"
1152                ## User
1153
1154                Hello
1155
1156                ## Assistant
1157
1158                Hey!
1159            "}
1160        );
1161    });
1162
1163    thread
1164        .update(cx, |thread, _cx| thread.truncate(message_id))
1165        .unwrap();
1166    cx.run_until_parked();
1167    thread.read_with(cx, |thread, _| {
1168        assert_eq!(thread.to_markdown(), "");
1169    });
1170
1171    // Ensure we can still send a new message after truncation.
1172    thread
1173        .update(cx, |thread, cx| {
1174            thread.send(UserMessageId::new(), ["Hi"], cx)
1175        })
1176        .unwrap();
1177    thread.update(cx, |thread, _cx| {
1178        assert_eq!(
1179            thread.to_markdown(),
1180            indoc! {"
1181                ## User
1182
1183                Hi
1184            "}
1185        );
1186    });
1187    cx.run_until_parked();
1188    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1189    cx.run_until_parked();
1190    thread.read_with(cx, |thread, _| {
1191        assert_eq!(
1192            thread.to_markdown(),
1193            indoc! {"
1194                ## User
1195
1196                Hi
1197
1198                ## Assistant
1199
1200                Ahoy!
1201            "}
1202        );
1203    });
1204}
1205
1206#[gpui::test]
1207async fn test_agent_connection(cx: &mut TestAppContext) {
1208    cx.update(settings::init);
1209    let templates = Templates::new();
1210
1211    // Initialize language model system with test provider
1212    cx.update(|cx| {
1213        gpui_tokio::init(cx);
1214        client::init_settings(cx);
1215
1216        let http_client = FakeHttpClient::with_404_response();
1217        let clock = Arc::new(clock::FakeSystemClock::new());
1218        let client = Client::new(clock, http_client, cx);
1219        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1220        language_model::init(client.clone(), cx);
1221        language_models::init(user_store.clone(), client.clone(), cx);
1222        Project::init_settings(cx);
1223        LanguageModelRegistry::test(cx);
1224        agent_settings::init(cx);
1225    });
1226    cx.executor().forbid_parking();
1227
1228    // Create a project for new_thread
1229    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
1230    fake_fs.insert_tree(path!("/test"), json!({})).await;
1231    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
1232    let cwd = Path::new("/test");
1233
1234    // Create agent and connection
1235    let agent = NativeAgent::new(
1236        project.clone(),
1237        templates.clone(),
1238        None,
1239        fake_fs.clone(),
1240        &mut cx.to_async(),
1241    )
1242    .await
1243    .unwrap();
1244    let connection = NativeAgentConnection(agent.clone());
1245
1246    // Test model_selector returns Some
1247    let selector_opt = connection.model_selector();
1248    assert!(
1249        selector_opt.is_some(),
1250        "agent2 should always support ModelSelector"
1251    );
1252    let selector = selector_opt.unwrap();
1253
1254    // Test list_models
1255    let listed_models = cx
1256        .update(|cx| selector.list_models(cx))
1257        .await
1258        .expect("list_models should succeed");
1259    let AgentModelList::Grouped(listed_models) = listed_models else {
1260        panic!("Unexpected model list type");
1261    };
1262    assert!(!listed_models.is_empty(), "should have at least one model");
1263    assert_eq!(
1264        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
1265        "fake/fake"
1266    );
1267
1268    // Create a thread using new_thread
1269    let connection_rc = Rc::new(connection.clone());
1270    let acp_thread = cx
1271        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1272        .await
1273        .expect("new_thread should succeed");
1274
1275    // Get the session_id from the AcpThread
1276    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1277
1278    // Test selected_model returns the default
1279    let model = cx
1280        .update(|cx| selector.selected_model(&session_id, cx))
1281        .await
1282        .expect("selected_model should succeed");
1283    let model = cx
1284        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1285        .unwrap();
1286    let model = model.as_fake();
1287    assert_eq!(model.id().0, "fake", "should return default model");
1288
1289    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1290    cx.run_until_parked();
1291    model.send_last_completion_stream_text_chunk("def");
1292    cx.run_until_parked();
1293    acp_thread.read_with(cx, |thread, cx| {
1294        assert_eq!(
1295            thread.to_markdown(cx),
1296            indoc! {"
1297                ## User
1298
1299                abc
1300
1301                ## Assistant
1302
1303                def
1304
1305            "}
1306        )
1307    });
1308
1309    // Test cancel
1310    cx.update(|cx| connection.cancel(&session_id, cx));
1311    request.await.expect("prompt should fail gracefully");
1312
1313    // Ensure that dropping the ACP thread causes the native thread to be
1314    // dropped as well.
1315    cx.update(|_| drop(acp_thread));
1316    let result = cx
1317        .update(|cx| {
1318            connection.prompt(
1319                Some(acp_thread::UserMessageId::new()),
1320                acp::PromptRequest {
1321                    session_id: session_id.clone(),
1322                    prompt: vec!["ghi".into()],
1323                },
1324                cx,
1325            )
1326        })
1327        .await;
1328    assert_eq!(
1329        result.as_ref().unwrap_err().to_string(),
1330        "Session not found",
1331        "unexpected result: {:?}",
1332        result
1333    );
1334}
1335
1336#[gpui::test]
1337async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1338    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1339    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1340    let fake_model = model.as_fake();
1341
1342    let mut events = thread
1343        .update(cx, |thread, cx| {
1344            thread.send(UserMessageId::new(), ["Think"], cx)
1345        })
1346        .unwrap();
1347    cx.run_until_parked();
1348
1349    // Simulate streaming partial input.
1350    let input = json!({});
1351    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1352        LanguageModelToolUse {
1353            id: "1".into(),
1354            name: ThinkingTool.name().into(),
1355            raw_input: input.to_string(),
1356            input,
1357            is_input_complete: false,
1358        },
1359    ));
1360
1361    // Input streaming completed
1362    let input = json!({ "content": "Thinking hard!" });
1363    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1364        LanguageModelToolUse {
1365            id: "1".into(),
1366            name: "thinking".into(),
1367            raw_input: input.to_string(),
1368            input,
1369            is_input_complete: true,
1370        },
1371    ));
1372    fake_model.end_last_completion_stream();
1373    cx.run_until_parked();
1374
1375    let tool_call = expect_tool_call(&mut events).await;
1376    assert_eq!(
1377        tool_call,
1378        acp::ToolCall {
1379            id: acp::ToolCallId("1".into()),
1380            title: "Thinking".into(),
1381            kind: acp::ToolKind::Think,
1382            status: acp::ToolCallStatus::Pending,
1383            content: vec![],
1384            locations: vec![],
1385            raw_input: Some(json!({})),
1386            raw_output: None,
1387        }
1388    );
1389    let update = expect_tool_call_update_fields(&mut events).await;
1390    assert_eq!(
1391        update,
1392        acp::ToolCallUpdate {
1393            id: acp::ToolCallId("1".into()),
1394            fields: acp::ToolCallUpdateFields {
1395                title: Some("Thinking".into()),
1396                kind: Some(acp::ToolKind::Think),
1397                raw_input: Some(json!({ "content": "Thinking hard!" })),
1398                ..Default::default()
1399            },
1400        }
1401    );
1402    let update = expect_tool_call_update_fields(&mut events).await;
1403    assert_eq!(
1404        update,
1405        acp::ToolCallUpdate {
1406            id: acp::ToolCallId("1".into()),
1407            fields: acp::ToolCallUpdateFields {
1408                status: Some(acp::ToolCallStatus::InProgress),
1409                ..Default::default()
1410            },
1411        }
1412    );
1413    let update = expect_tool_call_update_fields(&mut events).await;
1414    assert_eq!(
1415        update,
1416        acp::ToolCallUpdate {
1417            id: acp::ToolCallId("1".into()),
1418            fields: acp::ToolCallUpdateFields {
1419                content: Some(vec!["Thinking hard!".into()]),
1420                ..Default::default()
1421            },
1422        }
1423    );
1424    let update = expect_tool_call_update_fields(&mut events).await;
1425    assert_eq!(
1426        update,
1427        acp::ToolCallUpdate {
1428            id: acp::ToolCallId("1".into()),
1429            fields: acp::ToolCallUpdateFields {
1430                status: Some(acp::ToolCallStatus::Completed),
1431                raw_output: Some("Finished thinking.".into()),
1432                ..Default::default()
1433            },
1434        }
1435    );
1436}
1437
1438#[gpui::test]
1439async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
1440    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1441    let fake_model = model.as_fake();
1442
1443    let mut events = thread
1444        .update(cx, |thread, cx| {
1445            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
1446            thread.send(UserMessageId::new(), ["Hello!"], cx)
1447        })
1448        .unwrap();
1449    cx.run_until_parked();
1450
1451    fake_model.send_last_completion_stream_text_chunk("Hey!");
1452    fake_model.end_last_completion_stream();
1453
1454    let mut retry_events = Vec::new();
1455    while let Some(Ok(event)) = events.next().await {
1456        match event {
1457            AgentResponseEvent::Retry(retry_status) => {
1458                retry_events.push(retry_status);
1459            }
1460            AgentResponseEvent::Stop(..) => break,
1461            _ => {}
1462        }
1463    }
1464
1465    assert_eq!(retry_events.len(), 0);
1466    thread.read_with(cx, |thread, _cx| {
1467        assert_eq!(
1468            thread.to_markdown(),
1469            indoc! {"
1470                ## User
1471
1472                Hello!
1473
1474                ## Assistant
1475
1476                Hey!
1477            "}
1478        )
1479    });
1480}
1481
1482#[gpui::test]
1483async fn test_send_retry_on_error(cx: &mut TestAppContext) {
1484    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1485    let fake_model = model.as_fake();
1486
1487    let mut events = thread
1488        .update(cx, |thread, cx| {
1489            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
1490            thread.send(UserMessageId::new(), ["Hello!"], cx)
1491        })
1492        .unwrap();
1493    cx.run_until_parked();
1494
1495    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
1496        provider: LanguageModelProviderName::new("Anthropic"),
1497        retry_after: Some(Duration::from_secs(3)),
1498    });
1499    fake_model.end_last_completion_stream();
1500
1501    cx.executor().advance_clock(Duration::from_secs(3));
1502    cx.run_until_parked();
1503
1504    fake_model.send_last_completion_stream_text_chunk("Hey!");
1505    fake_model.end_last_completion_stream();
1506
1507    let mut retry_events = Vec::new();
1508    while let Some(Ok(event)) = events.next().await {
1509        match event {
1510            AgentResponseEvent::Retry(retry_status) => {
1511                retry_events.push(retry_status);
1512            }
1513            AgentResponseEvent::Stop(..) => break,
1514            _ => {}
1515        }
1516    }
1517
1518    assert_eq!(retry_events.len(), 1);
1519    assert!(matches!(
1520        retry_events[0],
1521        acp_thread::RetryStatus { attempt: 1, .. }
1522    ));
1523    thread.read_with(cx, |thread, _cx| {
1524        assert_eq!(
1525            thread.to_markdown(),
1526            indoc! {"
1527                ## User
1528
1529                Hello!
1530
1531                ## Assistant
1532
1533                Hey!
1534            "}
1535        )
1536    });
1537}
1538
1539#[gpui::test]
1540async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
1541    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1542    let fake_model = model.as_fake();
1543
1544    let mut events = thread
1545        .update(cx, |thread, cx| {
1546            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
1547            thread.send(UserMessageId::new(), ["Hello!"], cx)
1548        })
1549        .unwrap();
1550    cx.run_until_parked();
1551
1552    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
1553        fake_model.send_last_completion_stream_error(
1554            LanguageModelCompletionError::ServerOverloaded {
1555                provider: LanguageModelProviderName::new("Anthropic"),
1556                retry_after: Some(Duration::from_secs(3)),
1557            },
1558        );
1559        fake_model.end_last_completion_stream();
1560        cx.executor().advance_clock(Duration::from_secs(3));
1561        cx.run_until_parked();
1562    }
1563
1564    let mut errors = Vec::new();
1565    let mut retry_events = Vec::new();
1566    while let Some(event) = events.next().await {
1567        match event {
1568            Ok(AgentResponseEvent::Retry(retry_status)) => {
1569                retry_events.push(retry_status);
1570            }
1571            Ok(AgentResponseEvent::Stop(..)) => break,
1572            Err(error) => errors.push(error),
1573            _ => {}
1574        }
1575    }
1576
1577    assert_eq!(
1578        retry_events.len(),
1579        crate::thread::MAX_RETRY_ATTEMPTS as usize
1580    );
1581    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
1582        assert_eq!(retry_events[i].attempt, i + 1);
1583    }
1584    assert_eq!(errors.len(), 1);
1585    let error = errors[0]
1586        .downcast_ref::<LanguageModelCompletionError>()
1587        .unwrap();
1588    assert!(matches!(
1589        error,
1590        LanguageModelCompletionError::ServerOverloaded { .. }
1591    ));
1592}
1593
1594/// Filters out the stop events for asserting against in tests
1595fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
1596    result_events
1597        .into_iter()
1598        .filter_map(|event| match event.unwrap() {
1599            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
1600            _ => None,
1601        })
1602        .collect()
1603}
1604
1605struct ThreadTest {
1606    model: Arc<dyn LanguageModel>,
1607    thread: Entity<Thread>,
1608    project_context: Entity<ProjectContext>,
1609    fs: Arc<FakeFs>,
1610}
1611
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// Looked up in the model registry under the id `claude-sonnet-4-latest`.
    Sonnet4,
    /// Looked up in the model registry under the id
    /// `claude-sonnet-4-thinking-latest`.
    Sonnet4Thinking,
    /// A `FakeLanguageModel` created by `setup`, with no registry lookup.
    Fake,
}
1617
1618impl TestModel {
1619    fn id(&self) -> LanguageModelId {
1620        match self {
1621            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1622            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1623            TestModel::Fake => unreachable!(),
1624        }
1625    }
1626}
1627
1628async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1629    cx.executor().allow_parking();
1630
1631    let fs = FakeFs::new(cx.background_executor.clone());
1632    fs.create_dir(paths::settings_file().parent().unwrap())
1633        .await
1634        .unwrap();
1635    fs.insert_file(
1636        paths::settings_file(),
1637        json!({
1638            "agent": {
1639                "default_profile": "test-profile",
1640                "profiles": {
1641                    "test-profile": {
1642                        "name": "Test Profile",
1643                        "tools": {
1644                            EchoTool.name(): true,
1645                            DelayTool.name(): true,
1646                            WordListTool.name(): true,
1647                            ToolRequiringPermission.name(): true,
1648                            InfiniteTool.name(): true,
1649                        }
1650                    }
1651                }
1652            }
1653        })
1654        .to_string()
1655        .into_bytes(),
1656    )
1657    .await;
1658
1659    cx.update(|cx| {
1660        settings::init(cx);
1661        Project::init_settings(cx);
1662        agent_settings::init(cx);
1663        gpui_tokio::init(cx);
1664        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1665        cx.set_http_client(Arc::new(http_client));
1666
1667        client::init_settings(cx);
1668        let client = Client::production(cx);
1669        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1670        language_model::init(client.clone(), cx);
1671        language_models::init(user_store.clone(), client.clone(), cx);
1672
1673        watch_settings(fs.clone(), cx);
1674    });
1675
1676    let templates = Templates::new();
1677
1678    fs.insert_tree(path!("/test"), json!({})).await;
1679    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1680
1681    let model = cx
1682        .update(|cx| {
1683            if let TestModel::Fake = model {
1684                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1685            } else {
1686                let model_id = model.id();
1687                let models = LanguageModelRegistry::read_global(cx);
1688                let model = models
1689                    .available_models(cx)
1690                    .find(|model| model.id() == model_id)
1691                    .unwrap();
1692
1693                let provider = models.provider(&model.provider_id()).unwrap();
1694                let authenticated = provider.authenticate(cx);
1695
1696                cx.spawn(async move |_cx| {
1697                    authenticated.await.unwrap();
1698                    model
1699                })
1700            }
1701        })
1702        .await;
1703
1704    let project_context = cx.new(|_cx| ProjectContext::default());
1705    let context_server_registry =
1706        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1707    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1708    let thread = cx.new(|cx| {
1709        Thread::new(
1710            project,
1711            project_context.clone(),
1712            context_server_registry,
1713            action_log,
1714            templates,
1715            Some(model.clone()),
1716            cx,
1717        )
1718    });
1719    ThreadTest {
1720        model,
1721        thread,
1722        project_context,
1723        fs,
1724    }
1725}
1726
#[cfg(test)]
#[ctor::ctor]
// Initializes `env_logger`, but only when the `RUST_LOG` environment variable
// is set to a readable value.
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
1734
1735fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1736    let fs = fs.clone();
1737    cx.spawn({
1738        async move |cx| {
1739            let mut new_settings_content_rx = settings::watch_config_file(
1740                cx.background_executor(),
1741                fs,
1742                paths::settings_file().clone(),
1743            );
1744
1745            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1746                cx.update(|cx| {
1747                    SettingsStore::update_global(cx, |settings, cx| {
1748                        settings.set_user_settings(&new_settings_content, cx)
1749                    })
1750                })
1751                .ok();
1752            }
1753        }
1754    })
1755    .detach();
1756}