context_tests.rs

   1use super::{AssistantEdit, MessageCacheMetadata};
   2use crate::slash_command_working_set::SlashCommandWorkingSet;
   3use crate::ToolWorkingSet;
   4use crate::{
   5    assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
   6    Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId,
   7    MessageStatus, PromptBuilder,
   8};
   9use anyhow::Result;
  10use assistant_slash_command::{
  11    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
  12    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult,
  13};
  14use collections::{HashMap, HashSet};
  15use fs::FakeFs;
  16use futures::{
  17    channel::mpsc,
  18    stream::{self, StreamExt},
  19};
  20use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  21use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  22use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  23use parking_lot::Mutex;
  24use pretty_assertions::assert_eq;
  25use project::Project;
  26use rand::prelude::*;
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{
  30    cell::RefCell,
  31    env,
  32    ops::Range,
  33    path::Path,
  34    rc::Rc,
  35    sync::{atomic::AtomicBool, Arc},
  36};
  37use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
  38use ui::{Context as _, IconName, WindowContext};
  39use unindent::Unindent;
  40use util::{
  41    test::{generate_marked_text, marked_text_ranges},
  42    RandomCharIter,
  43};
  44use workspace::Workspace;
  45
  46#[gpui::test]
  47fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  48    let settings_store = SettingsStore::test(cx);
  49    LanguageModelRegistry::test(cx);
  50    cx.set_global(settings_store);
  51    assistant_panel::init(cx);
  52    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  53    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  54    let context = cx.new_model(|cx| {
  55        Context::local(
  56            registry,
  57            None,
  58            None,
  59            prompt_builder.clone(),
  60            Arc::new(SlashCommandWorkingSet::default()),
  61            Arc::new(ToolWorkingSet::default()),
  62            cx,
  63        )
  64    });
  65    let buffer = context.read(cx).buffer.clone();
  66
  67    let message_1 = context.read(cx).message_anchors[0].clone();
  68    assert_eq!(
  69        messages(&context, cx),
  70        vec![(message_1.id, Role::User, 0..0)]
  71    );
  72
  73    let message_2 = context.update(cx, |context, cx| {
  74        context
  75            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  76            .unwrap()
  77    });
  78    assert_eq!(
  79        messages(&context, cx),
  80        vec![
  81            (message_1.id, Role::User, 0..1),
  82            (message_2.id, Role::Assistant, 1..1)
  83        ]
  84    );
  85
  86    buffer.update(cx, |buffer, cx| {
  87        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  88    });
  89    assert_eq!(
  90        messages(&context, cx),
  91        vec![
  92            (message_1.id, Role::User, 0..2),
  93            (message_2.id, Role::Assistant, 2..3)
  94        ]
  95    );
  96
  97    let message_3 = context.update(cx, |context, cx| {
  98        context
  99            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 100            .unwrap()
 101    });
 102    assert_eq!(
 103        messages(&context, cx),
 104        vec![
 105            (message_1.id, Role::User, 0..2),
 106            (message_2.id, Role::Assistant, 2..4),
 107            (message_3.id, Role::User, 4..4)
 108        ]
 109    );
 110
 111    let message_4 = context.update(cx, |context, cx| {
 112        context
 113            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 114            .unwrap()
 115    });
 116    assert_eq!(
 117        messages(&context, cx),
 118        vec![
 119            (message_1.id, Role::User, 0..2),
 120            (message_2.id, Role::Assistant, 2..4),
 121            (message_4.id, Role::User, 4..5),
 122            (message_3.id, Role::User, 5..5),
 123        ]
 124    );
 125
 126    buffer.update(cx, |buffer, cx| {
 127        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 128    });
 129    assert_eq!(
 130        messages(&context, cx),
 131        vec![
 132            (message_1.id, Role::User, 0..2),
 133            (message_2.id, Role::Assistant, 2..4),
 134            (message_4.id, Role::User, 4..6),
 135            (message_3.id, Role::User, 6..7),
 136        ]
 137    );
 138
 139    // Deleting across message boundaries merges the messages.
 140    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 141    assert_eq!(
 142        messages(&context, cx),
 143        vec![
 144            (message_1.id, Role::User, 0..3),
 145            (message_3.id, Role::User, 3..4),
 146        ]
 147    );
 148
 149    // Undoing the deletion should also undo the merge.
 150    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 151    assert_eq!(
 152        messages(&context, cx),
 153        vec![
 154            (message_1.id, Role::User, 0..2),
 155            (message_2.id, Role::Assistant, 2..4),
 156            (message_4.id, Role::User, 4..6),
 157            (message_3.id, Role::User, 6..7),
 158        ]
 159    );
 160
 161    // Redoing the deletion should also redo the merge.
 162    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 163    assert_eq!(
 164        messages(&context, cx),
 165        vec![
 166            (message_1.id, Role::User, 0..3),
 167            (message_3.id, Role::User, 3..4),
 168        ]
 169    );
 170
 171    // Ensure we can still insert after a merged message.
 172    let message_5 = context.update(cx, |context, cx| {
 173        context
 174            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 175            .unwrap()
 176    });
 177    assert_eq!(
 178        messages(&context, cx),
 179        vec![
 180            (message_1.id, Role::User, 0..3),
 181            (message_5.id, Role::System, 3..4),
 182            (message_3.id, Role::User, 4..5)
 183        ]
 184    );
 185}
 186
 187#[gpui::test]
 188fn test_message_splitting(cx: &mut AppContext) {
 189    let settings_store = SettingsStore::test(cx);
 190    cx.set_global(settings_store);
 191    LanguageModelRegistry::test(cx);
 192    assistant_panel::init(cx);
 193    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 194
 195    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 196    let context = cx.new_model(|cx| {
 197        Context::local(
 198            registry.clone(),
 199            None,
 200            None,
 201            prompt_builder.clone(),
 202            Arc::new(SlashCommandWorkingSet::default()),
 203            Arc::new(ToolWorkingSet::default()),
 204            cx,
 205        )
 206    });
 207    let buffer = context.read(cx).buffer.clone();
 208
 209    let message_1 = context.read(cx).message_anchors[0].clone();
 210    assert_eq!(
 211        messages(&context, cx),
 212        vec![(message_1.id, Role::User, 0..0)]
 213    );
 214
 215    buffer.update(cx, |buffer, cx| {
 216        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 217    });
 218
 219    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 220    let message_2 = message_2.unwrap();
 221
 222    // We recycle newlines in the middle of a split message
 223    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 224    assert_eq!(
 225        messages(&context, cx),
 226        vec![
 227            (message_1.id, Role::User, 0..4),
 228            (message_2.id, Role::User, 4..16),
 229        ]
 230    );
 231
 232    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 233    let message_3 = message_3.unwrap();
 234
 235    // We don't recycle newlines at the end of a split message
 236    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 237    assert_eq!(
 238        messages(&context, cx),
 239        vec![
 240            (message_1.id, Role::User, 0..4),
 241            (message_3.id, Role::User, 4..5),
 242            (message_2.id, Role::User, 5..17),
 243        ]
 244    );
 245
 246    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 247    let message_4 = message_4.unwrap();
 248    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 249    assert_eq!(
 250        messages(&context, cx),
 251        vec![
 252            (message_1.id, Role::User, 0..4),
 253            (message_3.id, Role::User, 4..5),
 254            (message_2.id, Role::User, 5..9),
 255            (message_4.id, Role::User, 9..17),
 256        ]
 257    );
 258
 259    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 260    let message_5 = message_5.unwrap();
 261    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 262    assert_eq!(
 263        messages(&context, cx),
 264        vec![
 265            (message_1.id, Role::User, 0..4),
 266            (message_3.id, Role::User, 4..5),
 267            (message_2.id, Role::User, 5..9),
 268            (message_4.id, Role::User, 9..10),
 269            (message_5.id, Role::User, 10..18),
 270        ]
 271    );
 272
 273    let (message_6, message_7) =
 274        context.update(cx, |context, cx| context.split_message(14..16, cx));
 275    let message_6 = message_6.unwrap();
 276    let message_7 = message_7.unwrap();
 277    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 278    assert_eq!(
 279        messages(&context, cx),
 280        vec![
 281            (message_1.id, Role::User, 0..4),
 282            (message_3.id, Role::User, 4..5),
 283            (message_2.id, Role::User, 5..9),
 284            (message_4.id, Role::User, 9..10),
 285            (message_5.id, Role::User, 10..14),
 286            (message_6.id, Role::User, 14..17),
 287            (message_7.id, Role::User, 17..19),
 288        ]
 289    );
 290}
 291
 292#[gpui::test]
 293fn test_messages_for_offsets(cx: &mut AppContext) {
 294    let settings_store = SettingsStore::test(cx);
 295    LanguageModelRegistry::test(cx);
 296    cx.set_global(settings_store);
 297    assistant_panel::init(cx);
 298    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 299    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 300    let context = cx.new_model(|cx| {
 301        Context::local(
 302            registry,
 303            None,
 304            None,
 305            prompt_builder.clone(),
 306            Arc::new(SlashCommandWorkingSet::default()),
 307            Arc::new(ToolWorkingSet::default()),
 308            cx,
 309        )
 310    });
 311    let buffer = context.read(cx).buffer.clone();
 312
 313    let message_1 = context.read(cx).message_anchors[0].clone();
 314    assert_eq!(
 315        messages(&context, cx),
 316        vec![(message_1.id, Role::User, 0..0)]
 317    );
 318
 319    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 320    let message_2 = context
 321        .update(cx, |context, cx| {
 322            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 323        })
 324        .unwrap();
 325    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 326
 327    let message_3 = context
 328        .update(cx, |context, cx| {
 329            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 330        })
 331        .unwrap();
 332    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 333
 334    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 335    assert_eq!(
 336        messages(&context, cx),
 337        vec![
 338            (message_1.id, Role::User, 0..4),
 339            (message_2.id, Role::User, 4..8),
 340            (message_3.id, Role::User, 8..11)
 341        ]
 342    );
 343
 344    assert_eq!(
 345        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 346        [message_1.id, message_2.id, message_3.id]
 347    );
 348    assert_eq!(
 349        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 350        [message_1.id, message_3.id]
 351    );
 352
 353    let message_4 = context
 354        .update(cx, |context, cx| {
 355            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 356        })
 357        .unwrap();
 358    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 359    assert_eq!(
 360        messages(&context, cx),
 361        vec![
 362            (message_1.id, Role::User, 0..4),
 363            (message_2.id, Role::User, 4..8),
 364            (message_3.id, Role::User, 8..12),
 365            (message_4.id, Role::User, 12..12)
 366        ]
 367    );
 368    assert_eq!(
 369        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 370        [message_1.id, message_2.id, message_3.id, message_4.id]
 371    );
 372
 373    fn message_ids_for_offsets(
 374        context: &Model<Context>,
 375        offsets: &[usize],
 376        cx: &AppContext,
 377    ) -> Vec<MessageId> {
 378        context
 379            .read(cx)
 380            .messages_for_offsets(offsets.iter().copied(), cx)
 381            .into_iter()
 382            .map(|message| message.id)
 383            .collect()
 384    }
 385}
 386
 387#[gpui::test]
 388async fn test_slash_commands(cx: &mut TestAppContext) {
 389    let settings_store = cx.update(SettingsStore::test);
 390    cx.set_global(settings_store);
 391    cx.update(LanguageModelRegistry::test);
 392    cx.update(Project::init_settings);
 393    cx.update(assistant_panel::init);
 394    let fs = FakeFs::new(cx.background_executor.clone());
 395
 396    fs.insert_tree(
 397        "/test",
 398        json!({
 399            "src": {
 400                "lib.rs": "fn one() -> usize { 1 }",
 401                "main.rs": "
 402                    use crate::one;
 403                    fn main() { one(); }
 404                ".unindent(),
 405            }
 406        }),
 407    )
 408    .await;
 409
 410    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 411    slash_command_registry.register_command(file_command::FileSlashCommand, false);
 412
 413    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 414    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 415    let context = cx.new_model(|cx| {
 416        Context::local(
 417            registry.clone(),
 418            None,
 419            None,
 420            prompt_builder.clone(),
 421            Arc::new(SlashCommandWorkingSet::default()),
 422            Arc::new(ToolWorkingSet::default()),
 423            cx,
 424        )
 425    });
 426
 427    #[derive(Default)]
 428    struct ContextRanges {
 429        parsed_commands: HashSet<Range<language::Anchor>>,
 430        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 431        output_sections: HashSet<Range<language::Anchor>>,
 432    }
 433
 434    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 435    context.update(cx, |_, cx| {
 436        cx.subscribe(&context, {
 437            let context_ranges = context_ranges.clone();
 438            move |context, _, event, _| {
 439                let mut context_ranges = context_ranges.borrow_mut();
 440                match event {
 441                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 442                        let command = context.invoked_slash_command(command_id).unwrap();
 443                        context_ranges
 444                            .command_outputs
 445                            .insert(*command_id, command.range.clone());
 446                    }
 447                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 448                        for range in removed {
 449                            context_ranges.parsed_commands.remove(range);
 450                        }
 451                        for command in updated {
 452                            context_ranges
 453                                .parsed_commands
 454                                .insert(command.source_range.clone());
 455                        }
 456                    }
 457                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 458                        context_ranges.output_sections.insert(section.range.clone());
 459                    }
 460                    _ => {}
 461                }
 462            }
 463        })
 464        .detach();
 465    });
 466
 467    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 468
 469    // Insert a slash command
 470    buffer.update(cx, |buffer, cx| {
 471        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 472    });
 473    assert_text_and_context_ranges(
 474        &buffer,
 475        &context_ranges,
 476        &"
 477        «/file src/lib.rs»"
 478            .unindent(),
 479        cx,
 480    );
 481
 482    // Edit the argument of the slash command.
 483    buffer.update(cx, |buffer, cx| {
 484        let edit_offset = buffer.text().find("lib.rs").unwrap();
 485        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 486    });
 487    assert_text_and_context_ranges(
 488        &buffer,
 489        &context_ranges,
 490        &"
 491        «/file src/main.rs»"
 492            .unindent(),
 493        cx,
 494    );
 495
 496    // Edit the name of the slash command, using one that doesn't exist.
 497    buffer.update(cx, |buffer, cx| {
 498        let edit_offset = buffer.text().find("/file").unwrap();
 499        buffer.edit(
 500            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 501            None,
 502            cx,
 503        );
 504    });
 505    assert_text_and_context_ranges(
 506        &buffer,
 507        &context_ranges,
 508        &"
 509        /unknown src/main.rs"
 510            .unindent(),
 511        cx,
 512    );
 513
 514    // Undoing the insertion of an non-existent slash command resorts the previous one.
 515    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 516    assert_text_and_context_ranges(
 517        &buffer,
 518        &context_ranges,
 519        &"
 520        «/file src/main.rs»"
 521            .unindent(),
 522        cx,
 523    );
 524
 525    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 526    context.update(cx, |context, cx| {
 527        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 528        context.insert_command_output(
 529            command_source_range,
 530            "file",
 531            Task::ready(Ok(command_output_rx.boxed())),
 532            true,
 533            cx,
 534        );
 535    });
 536    assert_text_and_context_ranges(
 537        &buffer,
 538        &context_ranges,
 539        &"
 540        ⟦«/file src/main.rs»
 541        …⟧
 542        "
 543        .unindent(),
 544        cx,
 545    );
 546
 547    command_output_tx
 548        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 549            icon: IconName::Ai,
 550            label: "src/main.rs".into(),
 551            metadata: None,
 552        }))
 553        .unwrap();
 554    command_output_tx
 555        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 556        .unwrap();
 557    cx.run_until_parked();
 558    assert_text_and_context_ranges(
 559        &buffer,
 560        &context_ranges,
 561        &"
 562        ⟦«/file src/main.rs»
 563        src/main.rs…⟧
 564        "
 565        .unindent(),
 566        cx,
 567    );
 568
 569    command_output_tx
 570        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 571        .unwrap();
 572    cx.run_until_parked();
 573    assert_text_and_context_ranges(
 574        &buffer,
 575        &context_ranges,
 576        &"
 577        ⟦«/file src/main.rs»
 578        src/main.rs
 579        fn main() {}…⟧
 580        "
 581        .unindent(),
 582        cx,
 583    );
 584
 585    command_output_tx
 586        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 587        .unwrap();
 588    cx.run_until_parked();
 589    assert_text_and_context_ranges(
 590        &buffer,
 591        &context_ranges,
 592        &"
 593        ⟦«/file src/main.rs»
 594        ⟪src/main.rs
 595        fn main() {}⟫…⟧
 596        "
 597        .unindent(),
 598        cx,
 599    );
 600
 601    drop(command_output_tx);
 602    cx.run_until_parked();
 603    assert_text_and_context_ranges(
 604        &buffer,
 605        &context_ranges,
 606        &"
 607        ⟦⟪src/main.rs
 608        fn main() {}⟫⟧
 609        "
 610        .unindent(),
 611        cx,
 612    );
 613
 614    #[track_caller]
 615    fn assert_text_and_context_ranges(
 616        buffer: &Model<Buffer>,
 617        ranges: &RefCell<ContextRanges>,
 618        expected_marked_text: &str,
 619        cx: &mut TestAppContext,
 620    ) {
 621        let mut actual_marked_text = String::new();
 622        buffer.update(cx, |buffer, _| {
 623            struct Endpoint {
 624                offset: usize,
 625                marker: char,
 626            }
 627
 628            let ranges = ranges.borrow();
 629            let mut endpoints = Vec::new();
 630            for range in ranges.command_outputs.values() {
 631                endpoints.push(Endpoint {
 632                    offset: range.start.to_offset(buffer),
 633                    marker: '',
 634                });
 635            }
 636            for range in ranges.parsed_commands.iter() {
 637                endpoints.push(Endpoint {
 638                    offset: range.start.to_offset(buffer),
 639                    marker: '«',
 640                });
 641            }
 642            for range in ranges.output_sections.iter() {
 643                endpoints.push(Endpoint {
 644                    offset: range.start.to_offset(buffer),
 645                    marker: '',
 646                });
 647            }
 648
 649            for range in ranges.output_sections.iter() {
 650                endpoints.push(Endpoint {
 651                    offset: range.end.to_offset(buffer),
 652                    marker: '',
 653                });
 654            }
 655            for range in ranges.parsed_commands.iter() {
 656                endpoints.push(Endpoint {
 657                    offset: range.end.to_offset(buffer),
 658                    marker: '»',
 659                });
 660            }
 661            for range in ranges.command_outputs.values() {
 662                endpoints.push(Endpoint {
 663                    offset: range.end.to_offset(buffer),
 664                    marker: '',
 665                });
 666            }
 667
 668            endpoints.sort_by_key(|endpoint| endpoint.offset);
 669            let mut offset = 0;
 670            for endpoint in endpoints {
 671                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 672                actual_marked_text.push(endpoint.marker);
 673                offset = endpoint.offset;
 674            }
 675            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 676        });
 677
 678        assert_eq!(actual_marked_text, expected_marked_text);
 679    }
 680}
 681
 682#[gpui::test]
 683async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 684    cx.update(prompt_library::init);
 685    let mut settings_store = cx.update(SettingsStore::test);
 686    cx.update(|cx| {
 687        settings_store
 688            .set_user_settings(
 689                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 690                cx,
 691            )
 692            .unwrap()
 693    });
 694    cx.set_global(settings_store);
 695    cx.update(language::init);
 696    cx.update(Project::init_settings);
 697    let fs = FakeFs::new(cx.executor());
 698    let project = Project::test(fs, [Path::new("/root")], cx).await;
 699    cx.update(LanguageModelRegistry::test);
 700
 701    cx.update(assistant_panel::init);
 702    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 703
 704    // Create a new context
 705    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 706    let context = cx.new_model(|cx| {
 707        Context::local(
 708            registry.clone(),
 709            Some(project),
 710            None,
 711            prompt_builder.clone(),
 712            Arc::new(SlashCommandWorkingSet::default()),
 713            Arc::new(ToolWorkingSet::default()),
 714            cx,
 715        )
 716    });
 717
 718    // Insert an assistant message to simulate a response.
 719    let assistant_message_id = context.update(cx, |context, cx| {
 720        let user_message_id = context.messages(cx).next().unwrap().id;
 721        context
 722            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 723            .unwrap()
 724            .id
 725    });
 726
 727    // No edit tags
 728    edit(
 729        &context,
 730        "
 731
 732        «one
 733        two
 734        »",
 735        cx,
 736    );
 737    expect_patches(
 738        &context,
 739        "
 740
 741        one
 742        two
 743        ",
 744        &[],
 745        cx,
 746    );
 747
 748    // Partial edit step tag is added
 749    edit(
 750        &context,
 751        "
 752
 753        one
 754        two
 755        «
 756        <patch»",
 757        cx,
 758    );
 759    expect_patches(
 760        &context,
 761        "
 762
 763        one
 764        two
 765
 766        <patch",
 767        &[],
 768        cx,
 769    );
 770
 771    // The rest of the step tag is added. The unclosed
 772    // step is treated as incomplete.
 773    edit(
 774        &context,
 775        "
 776
 777        one
 778        two
 779
 780        <patch«>
 781        <edit>»",
 782        cx,
 783    );
 784    expect_patches(
 785        &context,
 786        "
 787
 788        one
 789        two
 790
 791        «<patch>
 792        <edit>»",
 793        &[&[]],
 794        cx,
 795    );
 796
 797    // The full patch is added
 798    edit(
 799        &context,
 800        "
 801
 802        one
 803        two
 804
 805        <patch>
 806        <edit>«
 807        <description>add a `two` function</description>
 808        <path>src/lib.rs</path>
 809        <operation>insert_after</operation>
 810        <old_text>fn one</old_text>
 811        <new_text>
 812        fn two() {}
 813        </new_text>
 814        </edit>
 815        </patch>
 816
 817        also,»",
 818        cx,
 819    );
 820    expect_patches(
 821        &context,
 822        "
 823
 824        one
 825        two
 826
 827        «<patch>
 828        <edit>
 829        <description>add a `two` function</description>
 830        <path>src/lib.rs</path>
 831        <operation>insert_after</operation>
 832        <old_text>fn one</old_text>
 833        <new_text>
 834        fn two() {}
 835        </new_text>
 836        </edit>
 837        </patch>
 838        »
 839        also,",
 840        &[&[AssistantEdit {
 841            path: "src/lib.rs".into(),
 842            kind: AssistantEditKind::InsertAfter {
 843                old_text: "fn one".into(),
 844                new_text: "fn two() {}".into(),
 845                description: Some("add a `two` function".into()),
 846            },
 847        }]],
 848        cx,
 849    );
 850
 851    // The step is manually edited.
 852    edit(
 853        &context,
 854        "
 855
 856        one
 857        two
 858
 859        <patch>
 860        <edit>
 861        <description>add a `two` function</description>
 862        <path>src/lib.rs</path>
 863        <operation>insert_after</operation>
 864        <old_text>«fn zero»</old_text>
 865        <new_text>
 866        fn two() {}
 867        </new_text>
 868        </edit>
 869        </patch>
 870
 871        also,",
 872        cx,
 873    );
 874    expect_patches(
 875        &context,
 876        "
 877
 878        one
 879        two
 880
 881        «<patch>
 882        <edit>
 883        <description>add a `two` function</description>
 884        <path>src/lib.rs</path>
 885        <operation>insert_after</operation>
 886        <old_text>fn zero</old_text>
 887        <new_text>
 888        fn two() {}
 889        </new_text>
 890        </edit>
 891        </patch>
 892        »
 893        also,",
 894        &[&[AssistantEdit {
 895            path: "src/lib.rs".into(),
 896            kind: AssistantEditKind::InsertAfter {
 897                old_text: "fn zero".into(),
 898                new_text: "fn two() {}".into(),
 899                description: Some("add a `two` function".into()),
 900            },
 901        }]],
 902        cx,
 903    );
 904
 905    // When setting the message role to User, the steps are cleared.
 906    context.update(cx, |context, cx| {
 907        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 908        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 909    });
 910    expect_patches(
 911        &context,
 912        "
 913
 914        one
 915        two
 916
 917        <patch>
 918        <edit>
 919        <description>add a `two` function</description>
 920        <path>src/lib.rs</path>
 921        <operation>insert_after</operation>
 922        <old_text>fn zero</old_text>
 923        <new_text>
 924        fn two() {}
 925        </new_text>
 926        </edit>
 927        </patch>
 928
 929        also,",
 930        &[],
 931        cx,
 932    );
 933
 934    // When setting the message role back to Assistant, the steps are reparsed.
 935    context.update(cx, |context, cx| {
 936        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 937    });
 938    expect_patches(
 939        &context,
 940        "
 941
 942        one
 943        two
 944
 945        «<patch>
 946        <edit>
 947        <description>add a `two` function</description>
 948        <path>src/lib.rs</path>
 949        <operation>insert_after</operation>
 950        <old_text>fn zero</old_text>
 951        <new_text>
 952        fn two() {}
 953        </new_text>
 954        </edit>
 955        </patch>
 956        »
 957        also,",
 958        &[&[AssistantEdit {
 959            path: "src/lib.rs".into(),
 960            kind: AssistantEditKind::InsertAfter {
 961                old_text: "fn zero".into(),
 962                new_text: "fn two() {}".into(),
 963                description: Some("add a `two` function".into()),
 964            },
 965        }]],
 966        cx,
 967    );
 968
 969    // Ensure steps are re-parsed when deserializing.
 970    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 971    let deserialized_context = cx.new_model(|cx| {
 972        Context::deserialize(
 973            serialized_context,
 974            Default::default(),
 975            registry.clone(),
 976            prompt_builder.clone(),
 977            Arc::new(SlashCommandWorkingSet::default()),
 978            Arc::new(ToolWorkingSet::default()),
 979            None,
 980            None,
 981            cx,
 982        )
 983    });
 984    expect_patches(
 985        &deserialized_context,
 986        "
 987
 988        one
 989        two
 990
 991        «<patch>
 992        <edit>
 993        <description>add a `two` function</description>
 994        <path>src/lib.rs</path>
 995        <operation>insert_after</operation>
 996        <old_text>fn zero</old_text>
 997        <new_text>
 998        fn two() {}
 999        </new_text>
1000        </edit>
1001        </patch>
1002        »
1003        also,",
1004        &[&[AssistantEdit {
1005            path: "src/lib.rs".into(),
1006            kind: AssistantEditKind::InsertAfter {
1007                old_text: "fn zero".into(),
1008                new_text: "fn two() {}".into(),
1009                description: Some("add a `two` function".into()),
1010            },
1011        }]],
1012        cx,
1013    );
1014
1015    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
1016        context.update(cx, |context, cx| {
1017            context.buffer.update(cx, |buffer, cx| {
1018                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1019            });
1020        });
1021        cx.executor().run_until_parked();
1022    }
1023
1024    #[track_caller]
1025    fn expect_patches(
1026        context: &Model<Context>,
1027        expected_marked_text: &str,
1028        expected_suggestions: &[&[AssistantEdit]],
1029        cx: &mut TestAppContext,
1030    ) {
1031        let expected_marked_text = expected_marked_text.unindent();
1032        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1033
1034        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1035            context.buffer.read_with(cx, |buffer, _| {
1036                let ranges = context
1037                    .patches
1038                    .iter()
1039                    .map(|entry| entry.range.to_offset(buffer))
1040                    .collect::<Vec<_>>();
1041                (
1042                    buffer.text(),
1043                    ranges,
1044                    context
1045                        .patches
1046                        .iter()
1047                        .map(|step| step.edits.clone())
1048                        .collect::<Vec<_>>(),
1049                )
1050            })
1051        });
1052
1053        assert_eq!(buffer_text, expected_text);
1054
1055        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1056        assert_eq!(actual_marked_text, expected_marked_text);
1057
1058        assert_eq!(
1059            patches
1060                .iter()
1061                .map(|patch| {
1062                    patch
1063                        .iter()
1064                        .map(|edit| {
1065                            let edit = edit.as_ref().unwrap();
1066                            AssistantEdit {
1067                                path: edit.path.clone(),
1068                                kind: edit.kind.clone(),
1069                            }
1070                        })
1071                        .collect::<Vec<_>>()
1072                })
1073                .collect::<Vec<_>>(),
1074            expected_suggestions
1075        );
1076    }
1077}
1078
1079#[gpui::test]
1080async fn test_serialization(cx: &mut TestAppContext) {
1081    let settings_store = cx.update(SettingsStore::test);
1082    cx.set_global(settings_store);
1083    cx.update(LanguageModelRegistry::test);
1084    cx.update(assistant_panel::init);
1085    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1086    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1087    let context = cx.new_model(|cx| {
1088        Context::local(
1089            registry.clone(),
1090            None,
1091            None,
1092            prompt_builder.clone(),
1093            Arc::new(SlashCommandWorkingSet::default()),
1094            Arc::new(ToolWorkingSet::default()),
1095            cx,
1096        )
1097    });
1098    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1099    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1100    let message_1 = context.update(cx, |context, cx| {
1101        context
1102            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1103            .unwrap()
1104    });
1105    let message_2 = context.update(cx, |context, cx| {
1106        context
1107            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1108            .unwrap()
1109    });
1110    buffer.update(cx, |buffer, cx| {
1111        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1112        buffer.finalize_last_transaction();
1113    });
1114    let _message_3 = context.update(cx, |context, cx| {
1115        context
1116            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1117            .unwrap()
1118    });
1119    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1120    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1121    assert_eq!(
1122        cx.read(|cx| messages(&context, cx)),
1123        [
1124            (message_0, Role::User, 0..2),
1125            (message_1.id, Role::Assistant, 2..6),
1126            (message_2.id, Role::System, 6..6),
1127        ]
1128    );
1129
1130    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1131    let deserialized_context = cx.new_model(|cx| {
1132        Context::deserialize(
1133            serialized_context,
1134            Default::default(),
1135            registry.clone(),
1136            prompt_builder.clone(),
1137            Arc::new(SlashCommandWorkingSet::default()),
1138            Arc::new(ToolWorkingSet::default()),
1139            None,
1140            None,
1141            cx,
1142        )
1143    });
1144    let deserialized_buffer =
1145        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1146    assert_eq!(
1147        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1148        "a\nb\nc\n"
1149    );
1150    assert_eq!(
1151        cx.read(|cx| messages(&deserialized_context, cx)),
1152        [
1153            (message_0, Role::User, 0..2),
1154            (message_1.id, Role::Assistant, 2..6),
1155            (message_2.id, Role::System, 6..6),
1156        ]
1157    );
1158}
1159
/// Randomized convergence test for several replicas of one `Context`.
///
/// Creates between `MIN_PEERS` and `MAX_PEERS` replicas (env vars, defaulting to 2 and 5) that
/// share one `ContextId` and are connected through a simulated `Network`. It then performs
/// `OPERATIONS` (default 50) random mutations: buffer edits, message splits, message insertions,
/// slash-command output insertions and message-status updates.
///
/// Mutations are interleaved with network activity. Replicas other than 0 may disconnect; a
/// disconnected replica reconnects by exchanging serialized operations with replica 0; and
/// connected replicas apply broadcast operations they have not yet received. Once mutations are
/// exhausted and the network is idle with no disconnected peers, every replica must match
/// replica 0 in buffer text, message anchors, message metadata and slash-command output
/// sections, and must have no pending operations.
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    // Test sizing can be overridden through environment variables.
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    // Fake commands whose names are picked at random for inserted slash-command text.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    // One replica per peer; the loop index doubles as the replica id.
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                Arc::new(SlashCommandWorkingSet::default()),
                Arc::new(ToolWorkingSet::default()),
                None,
                None,
                cx,
            )
        });

        // Every operation this replica emits is broadcast over the simulated network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Keep going until all mutations are done AND the network has fully converged.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // The mutation arms are guarded by `mutation_count > 0`; once exhausted, every roll
        // falls through to the network-handling arm.
        match rng.gen_range(0..100) {
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Insert the command on its own line; the returned range covers just
                    // the command text (skipping the leading newline).
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    // NOTE(review): '\r' is filtered out, presumably because the buffer
                    // normalizes line endings — confirm.
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(10)
                        .collect::<String>();

                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
                        role: Role::User,
                        merge_same_roles: true,
                    })];

                    // Split the output into up to 3 consecutive sections; any remainder is
                    // emitted below as plain content outside a section.
                    let num_sections = rng.gen_range(0..=3);
                    let mut section_start = 0;
                    for _ in 0..num_sections {
                        let mut section_end = rng.gen_range(section_start..=output_text.len());
                        // Advance to a char boundary so the slice below cannot split a
                        // multi-byte character.
                        while !output_text.is_char_boundary(section_end) {
                            section_end += 1;
                        }
                        events.push(Ok(SlashCommandEvent::StartSection {
                            icon: IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        }));
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..section_end].to_string(),
                            run_commands_in_text: false,
                        })));
                        events.push(Ok(SlashCommandEvent::EndSection));
                        section_start = section_end;
                    }

                    if section_start < output_text.len() {
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..].to_string(),
                            run_commands_in_text: false,
                        })));
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?} events",
                        context_index,
                        command_range,
                        events.len()
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        "/command",
                        Task::ready(Ok(stream::iter(events).boxed())),
                        true,
                        cx,
                    );
                });
                // Drive pending tasks, presumably so the command-output stream is fully
                // applied before the next iteration.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    // Resync with replica 0 (the host): each side serializes the
                    // operations the other side's version has not yet observed.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Replica 0 is never disconnected; it serves as the resync host.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Every replica must have converged to replica 0's state.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1443
1444#[gpui::test]
1445fn test_mark_cache_anchors(cx: &mut AppContext) {
1446    let settings_store = SettingsStore::test(cx);
1447    LanguageModelRegistry::test(cx);
1448    cx.set_global(settings_store);
1449    assistant_panel::init(cx);
1450    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1451    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1452    let context = cx.new_model(|cx| {
1453        Context::local(
1454            registry,
1455            None,
1456            None,
1457            prompt_builder.clone(),
1458            Arc::new(SlashCommandWorkingSet::default()),
1459            Arc::new(ToolWorkingSet::default()),
1460            cx,
1461        )
1462    });
1463    let buffer = context.read(cx).buffer.clone();
1464
1465    // Create a test cache configuration
1466    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1467        max_cache_anchors: 3,
1468        should_speculate: true,
1469        min_total_token: 10,
1470    });
1471
1472    let message_1 = context.read(cx).message_anchors[0].clone();
1473
1474    context.update(cx, |context, cx| {
1475        context.mark_cache_anchors(cache_configuration, false, cx)
1476    });
1477
1478    assert_eq!(
1479        messages_cache(&context, cx)
1480            .iter()
1481            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1482            .count(),
1483        0,
1484        "Empty messages should not have any cache anchors."
1485    );
1486
1487    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1488    let message_2 = context
1489        .update(cx, |context, cx| {
1490            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1491        })
1492        .unwrap();
1493
1494    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1495    let message_3 = context
1496        .update(cx, |context, cx| {
1497            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1498        })
1499        .unwrap();
1500    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1501
1502    context.update(cx, |context, cx| {
1503        context.mark_cache_anchors(cache_configuration, false, cx)
1504    });
1505    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1506    assert_eq!(
1507        messages_cache(&context, cx)
1508            .iter()
1509            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1510            .count(),
1511        0,
1512        "Messages should not be marked for cache before going over the token minimum."
1513    );
1514    context.update(cx, |context, _| {
1515        context.token_count = Some(20);
1516    });
1517
1518    context.update(cx, |context, cx| {
1519        context.mark_cache_anchors(cache_configuration, true, cx)
1520    });
1521    assert_eq!(
1522        messages_cache(&context, cx)
1523            .iter()
1524            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1525            .collect::<Vec<bool>>(),
1526        vec![true, true, false],
1527        "Last message should not be an anchor on speculative request."
1528    );
1529
1530    context
1531        .update(cx, |context, cx| {
1532            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1533        })
1534        .unwrap();
1535
1536    context.update(cx, |context, cx| {
1537        context.mark_cache_anchors(cache_configuration, false, cx)
1538    });
1539    assert_eq!(
1540        messages_cache(&context, cx)
1541            .iter()
1542            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1543            .collect::<Vec<bool>>(),
1544        vec![false, true, true, false],
1545        "Most recent message should also be cached if not a speculative request."
1546    );
1547    context.update(cx, |context, cx| {
1548        context.update_cache_status_for_completion(cx)
1549    });
1550    assert_eq!(
1551        messages_cache(&context, cx)
1552            .iter()
1553            .map(|(_, cache)| cache
1554                .as_ref()
1555                .map_or(None, |cache| Some(cache.status.clone())))
1556            .collect::<Vec<Option<CacheStatus>>>(),
1557        vec![
1558            Some(CacheStatus::Cached),
1559            Some(CacheStatus::Cached),
1560            Some(CacheStatus::Cached),
1561            None
1562        ],
1563        "All user messages prior to anchor should be marked as cached."
1564    );
1565
1566    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1567    context.update(cx, |context, cx| {
1568        context.mark_cache_anchors(cache_configuration, false, cx)
1569    });
1570    assert_eq!(
1571        messages_cache(&context, cx)
1572            .iter()
1573            .map(|(_, cache)| cache
1574                .as_ref()
1575                .map_or(None, |cache| Some(cache.status.clone())))
1576            .collect::<Vec<Option<CacheStatus>>>(),
1577        vec![
1578            Some(CacheStatus::Cached),
1579            Some(CacheStatus::Cached),
1580            Some(CacheStatus::Pending),
1581            None
1582        ],
1583        "Modifying a message should invalidate it's cache but leave previous messages."
1584    );
1585    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1586    context.update(cx, |context, cx| {
1587        context.mark_cache_anchors(cache_configuration, false, cx)
1588    });
1589    assert_eq!(
1590        messages_cache(&context, cx)
1591            .iter()
1592            .map(|(_, cache)| cache
1593                .as_ref()
1594                .map_or(None, |cache| Some(cache.status.clone())))
1595            .collect::<Vec<Option<CacheStatus>>>(),
1596        vec![
1597            Some(CacheStatus::Pending),
1598            Some(CacheStatus::Pending),
1599            Some(CacheStatus::Pending),
1600            None
1601        ],
1602        "Modifying a message should invalidate all future messages."
1603    );
1604}
1605
1606fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1607    context
1608        .read(cx)
1609        .messages(cx)
1610        .map(|message| (message.id, message.role, message.offset_range))
1611        .collect()
1612}
1613
1614fn messages_cache(
1615    context: &Model<Context>,
1616    cx: &AppContext,
1617) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1618    context
1619        .read(cx)
1620        .messages(cx)
1621        .map(|message| (message.id, message.cache.clone()))
1622        .collect()
1623}
1624
1625#[derive(Clone)]
1626struct FakeSlashCommand(String);
1627
1628impl SlashCommand for FakeSlashCommand {
1629    fn name(&self) -> String {
1630        self.0.clone()
1631    }
1632
1633    fn description(&self) -> String {
1634        format!("Fake slash command: {}", self.0)
1635    }
1636
1637    fn menu_text(&self) -> String {
1638        format!("Run fake command: {}", self.0)
1639    }
1640
1641    fn complete_argument(
1642        self: Arc<Self>,
1643        _arguments: &[String],
1644        _cancel: Arc<AtomicBool>,
1645        _workspace: Option<WeakView<Workspace>>,
1646        _cx: &mut WindowContext,
1647    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1648        Task::ready(Ok(vec![]))
1649    }
1650
1651    fn requires_argument(&self) -> bool {
1652        false
1653    }
1654
1655    fn run(
1656        self: Arc<Self>,
1657        _arguments: &[String],
1658        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1659        _context_buffer: BufferSnapshot,
1660        _workspace: WeakView<Workspace>,
1661        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1662        _cx: &mut WindowContext,
1663    ) -> Task<SlashCommandResult> {
1664        Task::ready(Ok(SlashCommandOutput {
1665            text: format!("Executed fake command: {}", self.0),
1666            sections: vec![],
1667            run_commands_in_text: false,
1668        }
1669        .to_event_stream()))
1670    }
1671}