context_tests.rs

   1use crate::{
   2    AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
   3    ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  20use parking_lot::Mutex;
  21use pretty_assertions::assert_eq;
  22use project::Project;
  23use prompt_store::PromptBuilder;
  24use rand::prelude::*;
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{
  28    cell::RefCell,
  29    env,
  30    ops::Range,
  31    path::Path,
  32    rc::Rc,
  33    sync::{Arc, atomic::AtomicBool},
  34};
  35use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
  36use ui::{IconName, Window};
  37use unindent::Unindent;
  38use util::{
  39    RandomCharIter,
  40    test::{generate_marked_text, marked_text_ranges},
  41};
  42use workspace::Workspace;
  43
  44#[gpui::test]
  45fn test_inserting_and_removing_messages(cx: &mut App) {
  46    init_test(cx);
  47
  48    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  49    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  50    let context = cx.new(|cx| {
  51        AssistantContext::local(
  52            registry,
  53            None,
  54            None,
  55            prompt_builder.clone(),
  56            Arc::new(SlashCommandWorkingSet::default()),
  57            cx,
  58        )
  59    });
  60    let buffer = context.read(cx).buffer.clone();
  61
  62    let message_1 = context.read(cx).message_anchors[0].clone();
  63    assert_eq!(
  64        messages(&context, cx),
  65        vec![(message_1.id, Role::User, 0..0)]
  66    );
  67
  68    let message_2 = context.update(cx, |context, cx| {
  69        context
  70            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  71            .unwrap()
  72    });
  73    assert_eq!(
  74        messages(&context, cx),
  75        vec![
  76            (message_1.id, Role::User, 0..1),
  77            (message_2.id, Role::Assistant, 1..1)
  78        ]
  79    );
  80
  81    buffer.update(cx, |buffer, cx| {
  82        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  83    });
  84    assert_eq!(
  85        messages(&context, cx),
  86        vec![
  87            (message_1.id, Role::User, 0..2),
  88            (message_2.id, Role::Assistant, 2..3)
  89        ]
  90    );
  91
  92    let message_3 = context.update(cx, |context, cx| {
  93        context
  94            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  95            .unwrap()
  96    });
  97    assert_eq!(
  98        messages(&context, cx),
  99        vec![
 100            (message_1.id, Role::User, 0..2),
 101            (message_2.id, Role::Assistant, 2..4),
 102            (message_3.id, Role::User, 4..4)
 103        ]
 104    );
 105
 106    let message_4 = context.update(cx, |context, cx| {
 107        context
 108            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 109            .unwrap()
 110    });
 111    assert_eq!(
 112        messages(&context, cx),
 113        vec![
 114            (message_1.id, Role::User, 0..2),
 115            (message_2.id, Role::Assistant, 2..4),
 116            (message_4.id, Role::User, 4..5),
 117            (message_3.id, Role::User, 5..5),
 118        ]
 119    );
 120
 121    buffer.update(cx, |buffer, cx| {
 122        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 123    });
 124    assert_eq!(
 125        messages(&context, cx),
 126        vec![
 127            (message_1.id, Role::User, 0..2),
 128            (message_2.id, Role::Assistant, 2..4),
 129            (message_4.id, Role::User, 4..6),
 130            (message_3.id, Role::User, 6..7),
 131        ]
 132    );
 133
 134    // Deleting across message boundaries merges the messages.
 135    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 136    assert_eq!(
 137        messages(&context, cx),
 138        vec![
 139            (message_1.id, Role::User, 0..3),
 140            (message_3.id, Role::User, 3..4),
 141        ]
 142    );
 143
 144    // Undoing the deletion should also undo the merge.
 145    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 146    assert_eq!(
 147        messages(&context, cx),
 148        vec![
 149            (message_1.id, Role::User, 0..2),
 150            (message_2.id, Role::Assistant, 2..4),
 151            (message_4.id, Role::User, 4..6),
 152            (message_3.id, Role::User, 6..7),
 153        ]
 154    );
 155
 156    // Redoing the deletion should also redo the merge.
 157    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 158    assert_eq!(
 159        messages(&context, cx),
 160        vec![
 161            (message_1.id, Role::User, 0..3),
 162            (message_3.id, Role::User, 3..4),
 163        ]
 164    );
 165
 166    // Ensure we can still insert after a merged message.
 167    let message_5 = context.update(cx, |context, cx| {
 168        context
 169            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 170            .unwrap()
 171    });
 172    assert_eq!(
 173        messages(&context, cx),
 174        vec![
 175            (message_1.id, Role::User, 0..3),
 176            (message_5.id, Role::System, 3..4),
 177            (message_3.id, Role::User, 4..5)
 178        ]
 179    );
 180}
 181
 182#[gpui::test]
 183fn test_message_splitting(cx: &mut App) {
 184    init_test(cx);
 185
 186    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 187
 188    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 189    let context = cx.new(|cx| {
 190        AssistantContext::local(
 191            registry.clone(),
 192            None,
 193            None,
 194            prompt_builder.clone(),
 195            Arc::new(SlashCommandWorkingSet::default()),
 196            cx,
 197        )
 198    });
 199    let buffer = context.read(cx).buffer.clone();
 200
 201    let message_1 = context.read(cx).message_anchors[0].clone();
 202    assert_eq!(
 203        messages(&context, cx),
 204        vec![(message_1.id, Role::User, 0..0)]
 205    );
 206
 207    buffer.update(cx, |buffer, cx| {
 208        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 209    });
 210
 211    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 212    let message_2 = message_2.unwrap();
 213
 214    // We recycle newlines in the middle of a split message
 215    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 216    assert_eq!(
 217        messages(&context, cx),
 218        vec![
 219            (message_1.id, Role::User, 0..4),
 220            (message_2.id, Role::User, 4..16),
 221        ]
 222    );
 223
 224    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 225    let message_3 = message_3.unwrap();
 226
 227    // We don't recycle newlines at the end of a split message
 228    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 229    assert_eq!(
 230        messages(&context, cx),
 231        vec![
 232            (message_1.id, Role::User, 0..4),
 233            (message_3.id, Role::User, 4..5),
 234            (message_2.id, Role::User, 5..17),
 235        ]
 236    );
 237
 238    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 239    let message_4 = message_4.unwrap();
 240    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 241    assert_eq!(
 242        messages(&context, cx),
 243        vec![
 244            (message_1.id, Role::User, 0..4),
 245            (message_3.id, Role::User, 4..5),
 246            (message_2.id, Role::User, 5..9),
 247            (message_4.id, Role::User, 9..17),
 248        ]
 249    );
 250
 251    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 252    let message_5 = message_5.unwrap();
 253    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 254    assert_eq!(
 255        messages(&context, cx),
 256        vec![
 257            (message_1.id, Role::User, 0..4),
 258            (message_3.id, Role::User, 4..5),
 259            (message_2.id, Role::User, 5..9),
 260            (message_4.id, Role::User, 9..10),
 261            (message_5.id, Role::User, 10..18),
 262        ]
 263    );
 264
 265    let (message_6, message_7) =
 266        context.update(cx, |context, cx| context.split_message(14..16, cx));
 267    let message_6 = message_6.unwrap();
 268    let message_7 = message_7.unwrap();
 269    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 270    assert_eq!(
 271        messages(&context, cx),
 272        vec![
 273            (message_1.id, Role::User, 0..4),
 274            (message_3.id, Role::User, 4..5),
 275            (message_2.id, Role::User, 5..9),
 276            (message_4.id, Role::User, 9..10),
 277            (message_5.id, Role::User, 10..14),
 278            (message_6.id, Role::User, 14..17),
 279            (message_7.id, Role::User, 17..19),
 280        ]
 281    );
 282}
 283
 284#[gpui::test]
 285fn test_messages_for_offsets(cx: &mut App) {
 286    init_test(cx);
 287
 288    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 289    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 290    let context = cx.new(|cx| {
 291        AssistantContext::local(
 292            registry,
 293            None,
 294            None,
 295            prompt_builder.clone(),
 296            Arc::new(SlashCommandWorkingSet::default()),
 297            cx,
 298        )
 299    });
 300    let buffer = context.read(cx).buffer.clone();
 301
 302    let message_1 = context.read(cx).message_anchors[0].clone();
 303    assert_eq!(
 304        messages(&context, cx),
 305        vec![(message_1.id, Role::User, 0..0)]
 306    );
 307
 308    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 309    let message_2 = context
 310        .update(cx, |context, cx| {
 311            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 312        })
 313        .unwrap();
 314    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 315
 316    let message_3 = context
 317        .update(cx, |context, cx| {
 318            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 319        })
 320        .unwrap();
 321    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 322
 323    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 324    assert_eq!(
 325        messages(&context, cx),
 326        vec![
 327            (message_1.id, Role::User, 0..4),
 328            (message_2.id, Role::User, 4..8),
 329            (message_3.id, Role::User, 8..11)
 330        ]
 331    );
 332
 333    assert_eq!(
 334        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 335        [message_1.id, message_2.id, message_3.id]
 336    );
 337    assert_eq!(
 338        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 339        [message_1.id, message_3.id]
 340    );
 341
 342    let message_4 = context
 343        .update(cx, |context, cx| {
 344            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 345        })
 346        .unwrap();
 347    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 348    assert_eq!(
 349        messages(&context, cx),
 350        vec![
 351            (message_1.id, Role::User, 0..4),
 352            (message_2.id, Role::User, 4..8),
 353            (message_3.id, Role::User, 8..12),
 354            (message_4.id, Role::User, 12..12)
 355        ]
 356    );
 357    assert_eq!(
 358        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 359        [message_1.id, message_2.id, message_3.id, message_4.id]
 360    );
 361
 362    fn message_ids_for_offsets(
 363        context: &Entity<AssistantContext>,
 364        offsets: &[usize],
 365        cx: &App,
 366    ) -> Vec<MessageId> {
 367        context
 368            .read(cx)
 369            .messages_for_offsets(offsets.iter().copied(), cx)
 370            .into_iter()
 371            .map(|message| message.id)
 372            .collect()
 373    }
 374}
 375
 376#[gpui::test]
 377async fn test_slash_commands(cx: &mut TestAppContext) {
 378    cx.update(init_test);
 379
 380    let fs = FakeFs::new(cx.background_executor.clone());
 381
 382    fs.insert_tree(
 383        "/test",
 384        json!({
 385            "src": {
 386                "lib.rs": "fn one() -> usize { 1 }",
 387                "main.rs": "
 388                    use crate::one;
 389                    fn main() { one(); }
 390                ".unindent(),
 391            }
 392        }),
 393    )
 394    .await;
 395
 396    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 397    slash_command_registry.register_command(FileSlashCommand, false);
 398
 399    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 400    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 401    let context = cx.new(|cx| {
 402        AssistantContext::local(
 403            registry.clone(),
 404            None,
 405            None,
 406            prompt_builder.clone(),
 407            Arc::new(SlashCommandWorkingSet::default()),
 408            cx,
 409        )
 410    });
 411
 412    #[derive(Default)]
 413    struct ContextRanges {
 414        parsed_commands: HashSet<Range<language::Anchor>>,
 415        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 416        output_sections: HashSet<Range<language::Anchor>>,
 417    }
 418
 419    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 420    context.update(cx, |_, cx| {
 421        cx.subscribe(&context, {
 422            let context_ranges = context_ranges.clone();
 423            move |context, _, event, _| {
 424                let mut context_ranges = context_ranges.borrow_mut();
 425                match event {
 426                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 427                        let command = context.invoked_slash_command(command_id).unwrap();
 428                        context_ranges
 429                            .command_outputs
 430                            .insert(*command_id, command.range.clone());
 431                    }
 432                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 433                        for range in removed {
 434                            context_ranges.parsed_commands.remove(range);
 435                        }
 436                        for command in updated {
 437                            context_ranges
 438                                .parsed_commands
 439                                .insert(command.source_range.clone());
 440                        }
 441                    }
 442                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 443                        context_ranges.output_sections.insert(section.range.clone());
 444                    }
 445                    _ => {}
 446                }
 447            }
 448        })
 449        .detach();
 450    });
 451
 452    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 453
 454    // Insert a slash command
 455    buffer.update(cx, |buffer, cx| {
 456        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 457    });
 458    assert_text_and_context_ranges(
 459        &buffer,
 460        &context_ranges,
 461        &"
 462        «/file src/lib.rs»"
 463            .unindent(),
 464        cx,
 465    );
 466
 467    // Edit the argument of the slash command.
 468    buffer.update(cx, |buffer, cx| {
 469        let edit_offset = buffer.text().find("lib.rs").unwrap();
 470        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 471    });
 472    assert_text_and_context_ranges(
 473        &buffer,
 474        &context_ranges,
 475        &"
 476        «/file src/main.rs»"
 477            .unindent(),
 478        cx,
 479    );
 480
 481    // Edit the name of the slash command, using one that doesn't exist.
 482    buffer.update(cx, |buffer, cx| {
 483        let edit_offset = buffer.text().find("/file").unwrap();
 484        buffer.edit(
 485            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 486            None,
 487            cx,
 488        );
 489    });
 490    assert_text_and_context_ranges(
 491        &buffer,
 492        &context_ranges,
 493        &"
 494        /unknown src/main.rs"
 495            .unindent(),
 496        cx,
 497    );
 498
 499    // Undoing the insertion of an non-existent slash command resorts the previous one.
 500    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 501    assert_text_and_context_ranges(
 502        &buffer,
 503        &context_ranges,
 504        &"
 505        «/file src/main.rs»"
 506            .unindent(),
 507        cx,
 508    );
 509
 510    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 511    context.update(cx, |context, cx| {
 512        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 513        context.insert_command_output(
 514            command_source_range,
 515            "file",
 516            Task::ready(Ok(command_output_rx.boxed())),
 517            true,
 518            cx,
 519        );
 520    });
 521    assert_text_and_context_ranges(
 522        &buffer,
 523        &context_ranges,
 524        &"
 525        ⟦«/file src/main.rs»
 526        …⟧
 527        "
 528        .unindent(),
 529        cx,
 530    );
 531
 532    command_output_tx
 533        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 534            icon: IconName::Ai,
 535            label: "src/main.rs".into(),
 536            metadata: None,
 537        }))
 538        .unwrap();
 539    command_output_tx
 540        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 541        .unwrap();
 542    cx.run_until_parked();
 543    assert_text_and_context_ranges(
 544        &buffer,
 545        &context_ranges,
 546        &"
 547        ⟦«/file src/main.rs»
 548        src/main.rs…⟧
 549        "
 550        .unindent(),
 551        cx,
 552    );
 553
 554    command_output_tx
 555        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 556        .unwrap();
 557    cx.run_until_parked();
 558    assert_text_and_context_ranges(
 559        &buffer,
 560        &context_ranges,
 561        &"
 562        ⟦«/file src/main.rs»
 563        src/main.rs
 564        fn main() {}…⟧
 565        "
 566        .unindent(),
 567        cx,
 568    );
 569
 570    command_output_tx
 571        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 572        .unwrap();
 573    cx.run_until_parked();
 574    assert_text_and_context_ranges(
 575        &buffer,
 576        &context_ranges,
 577        &"
 578        ⟦«/file src/main.rs»
 579        ⟪src/main.rs
 580        fn main() {}⟫…⟧
 581        "
 582        .unindent(),
 583        cx,
 584    );
 585
 586    drop(command_output_tx);
 587    cx.run_until_parked();
 588    assert_text_and_context_ranges(
 589        &buffer,
 590        &context_ranges,
 591        &"
 592        ⟦⟪src/main.rs
 593        fn main() {}⟫⟧
 594        "
 595        .unindent(),
 596        cx,
 597    );
 598
 599    #[track_caller]
 600    fn assert_text_and_context_ranges(
 601        buffer: &Entity<Buffer>,
 602        ranges: &RefCell<ContextRanges>,
 603        expected_marked_text: &str,
 604        cx: &mut TestAppContext,
 605    ) {
 606        let mut actual_marked_text = String::new();
 607        buffer.update(cx, |buffer, _| {
 608            struct Endpoint {
 609                offset: usize,
 610                marker: char,
 611            }
 612
 613            let ranges = ranges.borrow();
 614            let mut endpoints = Vec::new();
 615            for range in ranges.command_outputs.values() {
 616                endpoints.push(Endpoint {
 617                    offset: range.start.to_offset(buffer),
 618                    marker: '',
 619                });
 620            }
 621            for range in ranges.parsed_commands.iter() {
 622                endpoints.push(Endpoint {
 623                    offset: range.start.to_offset(buffer),
 624                    marker: '«',
 625                });
 626            }
 627            for range in ranges.output_sections.iter() {
 628                endpoints.push(Endpoint {
 629                    offset: range.start.to_offset(buffer),
 630                    marker: '',
 631                });
 632            }
 633
 634            for range in ranges.output_sections.iter() {
 635                endpoints.push(Endpoint {
 636                    offset: range.end.to_offset(buffer),
 637                    marker: '',
 638                });
 639            }
 640            for range in ranges.parsed_commands.iter() {
 641                endpoints.push(Endpoint {
 642                    offset: range.end.to_offset(buffer),
 643                    marker: '»',
 644                });
 645            }
 646            for range in ranges.command_outputs.values() {
 647                endpoints.push(Endpoint {
 648                    offset: range.end.to_offset(buffer),
 649                    marker: '',
 650                });
 651            }
 652
 653            endpoints.sort_by_key(|endpoint| endpoint.offset);
 654            let mut offset = 0;
 655            for endpoint in endpoints {
 656                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 657                actual_marked_text.push(endpoint.marker);
 658                offset = endpoint.offset;
 659            }
 660            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 661        });
 662
 663        assert_eq!(actual_marked_text, expected_marked_text);
 664    }
 665}
 666
 667#[gpui::test]
 668async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 669    cx.update(|cx| {
 670        init_test(cx);
 671        cx.update_global(|settings_store: &mut SettingsStore, cx| {
 672            settings_store
 673                .set_user_settings(
 674                    r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 675                    cx,
 676                )
 677                .unwrap()
 678        })
 679    });
 680    let fs = FakeFs::new(cx.executor());
 681    let project = Project::test(fs, [Path::new("/root")], cx).await;
 682
 683    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 684
 685    // Create a new context
 686    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 687    let context = cx.new(|cx| {
 688        AssistantContext::local(
 689            registry.clone(),
 690            Some(project),
 691            None,
 692            prompt_builder.clone(),
 693            Arc::new(SlashCommandWorkingSet::default()),
 694            cx,
 695        )
 696    });
 697
 698    // Insert an assistant message to simulate a response.
 699    let assistant_message_id = context.update(cx, |context, cx| {
 700        let user_message_id = context.messages(cx).next().unwrap().id;
 701        context
 702            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 703            .unwrap()
 704            .id
 705    });
 706
 707    // No edit tags
 708    edit(
 709        &context,
 710        "
 711
 712        «one
 713        two
 714        »",
 715        cx,
 716    );
 717    expect_patches(
 718        &context,
 719        "
 720
 721        one
 722        two
 723        ",
 724        &[],
 725        cx,
 726    );
 727
 728    // Partial edit step tag is added
 729    edit(
 730        &context,
 731        "
 732
 733        one
 734        two
 735        «
 736        <patch»",
 737        cx,
 738    );
 739    expect_patches(
 740        &context,
 741        "
 742
 743        one
 744        two
 745
 746        <patch",
 747        &[],
 748        cx,
 749    );
 750
 751    // The rest of the step tag is added. The unclosed
 752    // step is treated as incomplete.
 753    edit(
 754        &context,
 755        "
 756
 757        one
 758        two
 759
 760        <patch«>
 761        <edit>»",
 762        cx,
 763    );
 764    expect_patches(
 765        &context,
 766        "
 767
 768        one
 769        two
 770
 771        «<patch>
 772        <edit>»",
 773        &[&[]],
 774        cx,
 775    );
 776
 777    // The full patch is added
 778    edit(
 779        &context,
 780        "
 781
 782        one
 783        two
 784
 785        <patch>
 786        <edit>«
 787        <description>add a `two` function</description>
 788        <path>src/lib.rs</path>
 789        <operation>insert_after</operation>
 790        <old_text>fn one</old_text>
 791        <new_text>
 792        fn two() {}
 793        </new_text>
 794        </edit>
 795        </patch>
 796
 797        also,»",
 798        cx,
 799    );
 800    expect_patches(
 801        &context,
 802        "
 803
 804        one
 805        two
 806
 807        «<patch>
 808        <edit>
 809        <description>add a `two` function</description>
 810        <path>src/lib.rs</path>
 811        <operation>insert_after</operation>
 812        <old_text>fn one</old_text>
 813        <new_text>
 814        fn two() {}
 815        </new_text>
 816        </edit>
 817        </patch>
 818        »
 819        also,",
 820        &[&[AssistantEdit {
 821            path: "src/lib.rs".into(),
 822            kind: AssistantEditKind::InsertAfter {
 823                old_text: "fn one".into(),
 824                new_text: "fn two() {}".into(),
 825                description: Some("add a `two` function".into()),
 826            },
 827        }]],
 828        cx,
 829    );
 830
 831    // The step is manually edited.
 832    edit(
 833        &context,
 834        "
 835
 836        one
 837        two
 838
 839        <patch>
 840        <edit>
 841        <description>add a `two` function</description>
 842        <path>src/lib.rs</path>
 843        <operation>insert_after</operation>
 844        <old_text>«fn zero»</old_text>
 845        <new_text>
 846        fn two() {}
 847        </new_text>
 848        </edit>
 849        </patch>
 850
 851        also,",
 852        cx,
 853    );
 854    expect_patches(
 855        &context,
 856        "
 857
 858        one
 859        two
 860
 861        «<patch>
 862        <edit>
 863        <description>add a `two` function</description>
 864        <path>src/lib.rs</path>
 865        <operation>insert_after</operation>
 866        <old_text>fn zero</old_text>
 867        <new_text>
 868        fn two() {}
 869        </new_text>
 870        </edit>
 871        </patch>
 872        »
 873        also,",
 874        &[&[AssistantEdit {
 875            path: "src/lib.rs".into(),
 876            kind: AssistantEditKind::InsertAfter {
 877                old_text: "fn zero".into(),
 878                new_text: "fn two() {}".into(),
 879                description: Some("add a `two` function".into()),
 880            },
 881        }]],
 882        cx,
 883    );
 884
 885    // When setting the message role to User, the steps are cleared.
 886    context.update(cx, |context, cx| {
 887        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 888        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 889    });
 890    expect_patches(
 891        &context,
 892        "
 893
 894        one
 895        two
 896
 897        <patch>
 898        <edit>
 899        <description>add a `two` function</description>
 900        <path>src/lib.rs</path>
 901        <operation>insert_after</operation>
 902        <old_text>fn zero</old_text>
 903        <new_text>
 904        fn two() {}
 905        </new_text>
 906        </edit>
 907        </patch>
 908
 909        also,",
 910        &[],
 911        cx,
 912    );
 913
 914    // When setting the message role back to Assistant, the steps are reparsed.
 915    context.update(cx, |context, cx| {
 916        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 917    });
 918    expect_patches(
 919        &context,
 920        "
 921
 922        one
 923        two
 924
 925        «<patch>
 926        <edit>
 927        <description>add a `two` function</description>
 928        <path>src/lib.rs</path>
 929        <operation>insert_after</operation>
 930        <old_text>fn zero</old_text>
 931        <new_text>
 932        fn two() {}
 933        </new_text>
 934        </edit>
 935        </patch>
 936        »
 937        also,",
 938        &[&[AssistantEdit {
 939            path: "src/lib.rs".into(),
 940            kind: AssistantEditKind::InsertAfter {
 941                old_text: "fn zero".into(),
 942                new_text: "fn two() {}".into(),
 943                description: Some("add a `two` function".into()),
 944            },
 945        }]],
 946        cx,
 947    );
 948
 949    // Ensure steps are re-parsed when deserializing.
 950    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 951    let deserialized_context = cx.new(|cx| {
 952        AssistantContext::deserialize(
 953            serialized_context,
 954            Path::new("").into(),
 955            registry.clone(),
 956            prompt_builder.clone(),
 957            Arc::new(SlashCommandWorkingSet::default()),
 958            None,
 959            None,
 960            cx,
 961        )
 962    });
 963    expect_patches(
 964        &deserialized_context,
 965        "
 966
 967        one
 968        two
 969
 970        «<patch>
 971        <edit>
 972        <description>add a `two` function</description>
 973        <path>src/lib.rs</path>
 974        <operation>insert_after</operation>
 975        <old_text>fn zero</old_text>
 976        <new_text>
 977        fn two() {}
 978        </new_text>
 979        </edit>
 980        </patch>
 981        »
 982        also,",
 983        &[&[AssistantEdit {
 984            path: "src/lib.rs".into(),
 985            kind: AssistantEditKind::InsertAfter {
 986                old_text: "fn zero".into(),
 987                new_text: "fn two() {}".into(),
 988                description: Some("add a `two` function".into()),
 989            },
 990        }]],
 991        cx,
 992    );
 993
 994    fn edit(
 995        context: &Entity<AssistantContext>,
 996        new_text_marked_with_edits: &str,
 997        cx: &mut TestAppContext,
 998    ) {
 999        context.update(cx, |context, cx| {
1000            context.buffer.update(cx, |buffer, cx| {
1001                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1002            });
1003        });
1004        cx.executor().run_until_parked();
1005    }
1006
1007    #[track_caller]
1008    fn expect_patches(
1009        context: &Entity<AssistantContext>,
1010        expected_marked_text: &str,
1011        expected_suggestions: &[&[AssistantEdit]],
1012        cx: &mut TestAppContext,
1013    ) {
1014        let expected_marked_text = expected_marked_text.unindent();
1015        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1016
1017        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1018            context.buffer.read_with(cx, |buffer, _| {
1019                let ranges = context
1020                    .patches
1021                    .iter()
1022                    .map(|entry| entry.range.to_offset(buffer))
1023                    .collect::<Vec<_>>();
1024                (
1025                    buffer.text(),
1026                    ranges,
1027                    context
1028                        .patches
1029                        .iter()
1030                        .map(|step| step.edits.clone())
1031                        .collect::<Vec<_>>(),
1032                )
1033            })
1034        });
1035
1036        assert_eq!(buffer_text, expected_text);
1037
1038        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1039        assert_eq!(actual_marked_text, expected_marked_text);
1040
1041        assert_eq!(
1042            patches
1043                .iter()
1044                .map(|patch| {
1045                    patch
1046                        .iter()
1047                        .map(|edit| {
1048                            let edit = edit.as_ref().unwrap();
1049                            AssistantEdit {
1050                                path: edit.path.clone(),
1051                                kind: edit.kind.clone(),
1052                            }
1053                        })
1054                        .collect::<Vec<_>>()
1055                })
1056                .collect::<Vec<_>>(),
1057            expected_suggestions
1058        );
1059    }
1060}
1061
1062#[gpui::test]
1063async fn test_serialization(cx: &mut TestAppContext) {
1064    cx.update(init_test);
1065
1066    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1067    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1068    let context = cx.new(|cx| {
1069        AssistantContext::local(
1070            registry.clone(),
1071            None,
1072            None,
1073            prompt_builder.clone(),
1074            Arc::new(SlashCommandWorkingSet::default()),
1075            cx,
1076        )
1077    });
1078    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1079    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1080    let message_1 = context.update(cx, |context, cx| {
1081        context
1082            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1083            .unwrap()
1084    });
1085    let message_2 = context.update(cx, |context, cx| {
1086        context
1087            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1088            .unwrap()
1089    });
1090    buffer.update(cx, |buffer, cx| {
1091        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1092        buffer.finalize_last_transaction();
1093    });
1094    let _message_3 = context.update(cx, |context, cx| {
1095        context
1096            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1097            .unwrap()
1098    });
1099    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1100    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1101    assert_eq!(
1102        cx.read(|cx| messages(&context, cx)),
1103        [
1104            (message_0, Role::User, 0..2),
1105            (message_1.id, Role::Assistant, 2..6),
1106            (message_2.id, Role::System, 6..6),
1107        ]
1108    );
1109
1110    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1111    let deserialized_context = cx.new(|cx| {
1112        AssistantContext::deserialize(
1113            serialized_context,
1114            Path::new("").into(),
1115            registry.clone(),
1116            prompt_builder.clone(),
1117            Arc::new(SlashCommandWorkingSet::default()),
1118            None,
1119            None,
1120            cx,
1121        )
1122    });
1123    let deserialized_buffer =
1124        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1125    assert_eq!(
1126        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1127        "a\nb\nc\n"
1128    );
1129    assert_eq!(
1130        cx.read(|cx| messages(&deserialized_context, cx)),
1131        [
1132            (message_0, Role::User, 0..2),
1133            (message_1.id, Role::Assistant, 2..6),
1134            (message_2.id, Role::System, 6..6),
1135        ]
1136    );
1137}
1138
/// Randomized convergence test for collaboratively editing a single assistant context.
///
/// Several replicas of one context are wired together through a simulated `Network`. Each loop
/// iteration picks a random replica and either mutates it or exercises the network:
///
/// - Mutations: buffer edit, message split, message insertion, slash-command output insertion,
///   or message status update.
/// - Network actions: reconnect, disconnect, or apply received operations.
///
/// After the loop, every replica must match replica 0.
///
/// Environment variables: `MIN_PEERS` (default 2), `MAX_PEERS` (default 5), and `OPERATIONS`
/// (default 50, the number of mutations to perform).
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    cx.update(init_test);

    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    // Register fake commands so the slash-command arm below has names to choose from.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    // One replica per peer, all sharing the same `ContextId`; replica `i` uses `ReplicaId` `i`.
    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new(|cx| {
            AssistantContext::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                Arc::new(SlashCommandWorkingSet::default()),
                None,
                None,
                cx,
            )
        });

        // Forward every operation this replica emits onto the simulated network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Continue until all mutations are spent, `is_idle()` reports true, and no peer is
    // disconnected.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // The mutation arms are guarded by `mutation_count > 0`. Once mutations run out, every
        // roll falls through to the network arm (`_`).
        match rng.gen_range(0..100) {
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    // Insert "\n/<random command name>\n" at a random offset. `command_range`
                    // covers only the command text, excluding the surrounding newlines.
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    // Ten random characters (carriage returns filtered out) used as output.
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(10)
                        .collect::<String>();

                    // Synthetic event stream: start a user message, then 0-3 sections over
                    // consecutive slices of the output, then any leftover text outside a section.
                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
                        role: Role::User,
                        merge_same_roles: true,
                    })];

                    let num_sections = rng.gen_range(0..=3);
                    let mut section_start = 0;
                    for _ in 0..num_sections {
                        let mut section_end = rng.gen_range(section_start..=output_text.len());
                        // Advance to a UTF-8 char boundary so the slice below cannot panic.
                        while !output_text.is_char_boundary(section_end) {
                            section_end += 1;
                        }
                        events.push(Ok(SlashCommandEvent::StartSection {
                            icon: IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        }));
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..section_end].to_string(),
                            run_commands_in_text: false,
                        })));
                        events.push(Ok(SlashCommandEvent::EndSection));
                        section_start = section_end;
                    }

                    if section_start < output_text.len() {
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..].to_string(),
                            run_commands_in_text: false,
                        })));
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?} events",
                        context_index,
                        command_range,
                        events.len()
                    );

                    // Convert the offset range to anchors before handing it to the context.
                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        "/command",
                        Task::ready(Ok(stream::iter(events).boxed())),
                        true,
                        cx,
                    );
                });
                cx.run_until_parked();
                mutation_count -= 1;
            }
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    // Resynchronize with replica 0 by exchanging operations keyed on each side's
                    // version. Presumably `serialize_ops(version)` yields the operations not
                    // covered by `version`; confirm against `AssistantContext`.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Replica 0 is never disconnected; reconnecting peers sync against it.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Every replica must have converged to replica 0's state.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1418
1419#[gpui::test]
1420fn test_mark_cache_anchors(cx: &mut App) {
1421    init_test(cx);
1422
1423    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1424    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1425    let context = cx.new(|cx| {
1426        AssistantContext::local(
1427            registry,
1428            None,
1429            None,
1430            prompt_builder.clone(),
1431            Arc::new(SlashCommandWorkingSet::default()),
1432            cx,
1433        )
1434    });
1435    let buffer = context.read(cx).buffer.clone();
1436
1437    // Create a test cache configuration
1438    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1439        max_cache_anchors: 3,
1440        should_speculate: true,
1441        min_total_token: 10,
1442    });
1443
1444    let message_1 = context.read(cx).message_anchors[0].clone();
1445
1446    context.update(cx, |context, cx| {
1447        context.mark_cache_anchors(cache_configuration, false, cx)
1448    });
1449
1450    assert_eq!(
1451        messages_cache(&context, cx)
1452            .iter()
1453            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1454            .count(),
1455        0,
1456        "Empty messages should not have any cache anchors."
1457    );
1458
1459    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1460    let message_2 = context
1461        .update(cx, |context, cx| {
1462            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1463        })
1464        .unwrap();
1465
1466    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1467    let message_3 = context
1468        .update(cx, |context, cx| {
1469            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1470        })
1471        .unwrap();
1472    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1473
1474    context.update(cx, |context, cx| {
1475        context.mark_cache_anchors(cache_configuration, false, cx)
1476    });
1477    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1478    assert_eq!(
1479        messages_cache(&context, cx)
1480            .iter()
1481            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1482            .count(),
1483        0,
1484        "Messages should not be marked for cache before going over the token minimum."
1485    );
1486    context.update(cx, |context, _| {
1487        context.token_count = Some(20);
1488    });
1489
1490    context.update(cx, |context, cx| {
1491        context.mark_cache_anchors(cache_configuration, true, cx)
1492    });
1493    assert_eq!(
1494        messages_cache(&context, cx)
1495            .iter()
1496            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1497            .collect::<Vec<bool>>(),
1498        vec![true, true, false],
1499        "Last message should not be an anchor on speculative request."
1500    );
1501
1502    context
1503        .update(cx, |context, cx| {
1504            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1505        })
1506        .unwrap();
1507
1508    context.update(cx, |context, cx| {
1509        context.mark_cache_anchors(cache_configuration, false, cx)
1510    });
1511    assert_eq!(
1512        messages_cache(&context, cx)
1513            .iter()
1514            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1515            .collect::<Vec<bool>>(),
1516        vec![false, true, true, false],
1517        "Most recent message should also be cached if not a speculative request."
1518    );
1519    context.update(cx, |context, cx| {
1520        context.update_cache_status_for_completion(cx)
1521    });
1522    assert_eq!(
1523        messages_cache(&context, cx)
1524            .iter()
1525            .map(|(_, cache)| cache
1526                .as_ref()
1527                .map_or(None, |cache| Some(cache.status.clone())))
1528            .collect::<Vec<Option<CacheStatus>>>(),
1529        vec![
1530            Some(CacheStatus::Cached),
1531            Some(CacheStatus::Cached),
1532            Some(CacheStatus::Cached),
1533            None
1534        ],
1535        "All user messages prior to anchor should be marked as cached."
1536    );
1537
1538    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1539    context.update(cx, |context, cx| {
1540        context.mark_cache_anchors(cache_configuration, false, cx)
1541    });
1542    assert_eq!(
1543        messages_cache(&context, cx)
1544            .iter()
1545            .map(|(_, cache)| cache
1546                .as_ref()
1547                .map_or(None, |cache| Some(cache.status.clone())))
1548            .collect::<Vec<Option<CacheStatus>>>(),
1549        vec![
1550            Some(CacheStatus::Cached),
1551            Some(CacheStatus::Cached),
1552            Some(CacheStatus::Pending),
1553            None
1554        ],
1555        "Modifying a message should invalidate it's cache but leave previous messages."
1556    );
1557    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1558    context.update(cx, |context, cx| {
1559        context.mark_cache_anchors(cache_configuration, false, cx)
1560    });
1561    assert_eq!(
1562        messages_cache(&context, cx)
1563            .iter()
1564            .map(|(_, cache)| cache
1565                .as_ref()
1566                .map_or(None, |cache| Some(cache.status.clone())))
1567            .collect::<Vec<Option<CacheStatus>>>(),
1568        vec![
1569            Some(CacheStatus::Pending),
1570            Some(CacheStatus::Pending),
1571            Some(CacheStatus::Pending),
1572            None
1573        ],
1574        "Modifying a message should invalidate all future messages."
1575    );
1576}
1577
1578fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1579    context
1580        .read(cx)
1581        .messages(cx)
1582        .map(|message| (message.id, message.role, message.offset_range))
1583        .collect()
1584}
1585
1586fn messages_cache(
1587    context: &Entity<AssistantContext>,
1588    cx: &App,
1589) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1590    context
1591        .read(cx)
1592        .messages(cx)
1593        .map(|message| (message.id, message.cache.clone()))
1594        .collect()
1595}
1596
/// Shared setup for the tests in this file. It does the following:
///
/// - Creates a test `SettingsStore`.
/// - Calls `prompt_store::init`.
/// - Installs a test `LanguageModelRegistry`.
/// - Registers the settings store as a global.
/// - Initializes the language, assistant-settings, and project settings.
fn init_test(cx: &mut App) {
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    // The store is created first but only set as a global after the two calls above.
    // NOTE(review): this ordering looks deliberate; confirm before rearranging.
    cx.set_global(settings_store);
    language::init(cx);
    assistant_settings::init(cx);
    Project::init_settings(cx);
}
1606
1607#[derive(Clone)]
1608struct FakeSlashCommand(String);
1609
1610impl SlashCommand for FakeSlashCommand {
1611    fn name(&self) -> String {
1612        self.0.clone()
1613    }
1614
1615    fn description(&self) -> String {
1616        format!("Fake slash command: {}", self.0)
1617    }
1618
1619    fn menu_text(&self) -> String {
1620        format!("Run fake command: {}", self.0)
1621    }
1622
1623    fn complete_argument(
1624        self: Arc<Self>,
1625        _arguments: &[String],
1626        _cancel: Arc<AtomicBool>,
1627        _workspace: Option<WeakEntity<Workspace>>,
1628        _window: &mut Window,
1629        _cx: &mut App,
1630    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1631        Task::ready(Ok(vec![]))
1632    }
1633
1634    fn requires_argument(&self) -> bool {
1635        false
1636    }
1637
1638    fn run(
1639        self: Arc<Self>,
1640        _arguments: &[String],
1641        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1642        _context_buffer: BufferSnapshot,
1643        _workspace: WeakEntity<Workspace>,
1644        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1645        _window: &mut Window,
1646        _cx: &mut App,
1647    ) -> Task<SlashCommandResult> {
1648        Task::ready(Ok(SlashCommandOutput {
1649            text: format!("Executed fake command: {}", self.0),
1650            sections: vec![],
1651            run_commands_in_text: false,
1652        }
1653        .to_event_stream()))
1654    }
1655}