context_tests.rs

   1use crate::{
   2    AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
   3    ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use assistant_tool::ToolWorkingSet;
  12use collections::{HashMap, HashSet};
  13use fs::FakeFs;
  14use futures::{
  15    channel::mpsc,
  16    stream::{self, StreamExt},
  17};
  18use gpui::{prelude::*, App, Entity, SharedString, Task, TestAppContext, WeakEntity};
  19use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  20use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  21use parking_lot::Mutex;
  22use pretty_assertions::assert_eq;
  23use project::Project;
  24use prompt_library::PromptBuilder;
  25use rand::prelude::*;
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{
  29    cell::RefCell,
  30    env,
  31    ops::Range,
  32    path::Path,
  33    rc::Rc,
  34    sync::{atomic::AtomicBool, Arc},
  35};
  36use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
  37use ui::{IconName, Window};
  38use unindent::Unindent;
  39use util::{
  40    test::{generate_marked_text, marked_text_ranges},
  41    RandomCharIter,
  42};
  43use workspace::Workspace;
  44
  45#[gpui::test]
  46fn test_inserting_and_removing_messages(cx: &mut App) {
  47    let settings_store = SettingsStore::test(cx);
  48    LanguageModelRegistry::test(cx);
  49    cx.set_global(settings_store);
  50    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  51    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  52    let context = cx.new(|cx| {
  53        AssistantContext::local(
  54            registry,
  55            None,
  56            None,
  57            prompt_builder.clone(),
  58            Arc::new(SlashCommandWorkingSet::default()),
  59            Arc::new(ToolWorkingSet::default()),
  60            cx,
  61        )
  62    });
  63    let buffer = context.read(cx).buffer.clone();
  64
  65    let message_1 = context.read(cx).message_anchors[0].clone();
  66    assert_eq!(
  67        messages(&context, cx),
  68        vec![(message_1.id, Role::User, 0..0)]
  69    );
  70
  71    let message_2 = context.update(cx, |context, cx| {
  72        context
  73            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  74            .unwrap()
  75    });
  76    assert_eq!(
  77        messages(&context, cx),
  78        vec![
  79            (message_1.id, Role::User, 0..1),
  80            (message_2.id, Role::Assistant, 1..1)
  81        ]
  82    );
  83
  84    buffer.update(cx, |buffer, cx| {
  85        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  86    });
  87    assert_eq!(
  88        messages(&context, cx),
  89        vec![
  90            (message_1.id, Role::User, 0..2),
  91            (message_2.id, Role::Assistant, 2..3)
  92        ]
  93    );
  94
  95    let message_3 = context.update(cx, |context, cx| {
  96        context
  97            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  98            .unwrap()
  99    });
 100    assert_eq!(
 101        messages(&context, cx),
 102        vec![
 103            (message_1.id, Role::User, 0..2),
 104            (message_2.id, Role::Assistant, 2..4),
 105            (message_3.id, Role::User, 4..4)
 106        ]
 107    );
 108
 109    let message_4 = context.update(cx, |context, cx| {
 110        context
 111            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 112            .unwrap()
 113    });
 114    assert_eq!(
 115        messages(&context, cx),
 116        vec![
 117            (message_1.id, Role::User, 0..2),
 118            (message_2.id, Role::Assistant, 2..4),
 119            (message_4.id, Role::User, 4..5),
 120            (message_3.id, Role::User, 5..5),
 121        ]
 122    );
 123
 124    buffer.update(cx, |buffer, cx| {
 125        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 126    });
 127    assert_eq!(
 128        messages(&context, cx),
 129        vec![
 130            (message_1.id, Role::User, 0..2),
 131            (message_2.id, Role::Assistant, 2..4),
 132            (message_4.id, Role::User, 4..6),
 133            (message_3.id, Role::User, 6..7),
 134        ]
 135    );
 136
 137    // Deleting across message boundaries merges the messages.
 138    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 139    assert_eq!(
 140        messages(&context, cx),
 141        vec![
 142            (message_1.id, Role::User, 0..3),
 143            (message_3.id, Role::User, 3..4),
 144        ]
 145    );
 146
 147    // Undoing the deletion should also undo the merge.
 148    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 149    assert_eq!(
 150        messages(&context, cx),
 151        vec![
 152            (message_1.id, Role::User, 0..2),
 153            (message_2.id, Role::Assistant, 2..4),
 154            (message_4.id, Role::User, 4..6),
 155            (message_3.id, Role::User, 6..7),
 156        ]
 157    );
 158
 159    // Redoing the deletion should also redo the merge.
 160    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 161    assert_eq!(
 162        messages(&context, cx),
 163        vec![
 164            (message_1.id, Role::User, 0..3),
 165            (message_3.id, Role::User, 3..4),
 166        ]
 167    );
 168
 169    // Ensure we can still insert after a merged message.
 170    let message_5 = context.update(cx, |context, cx| {
 171        context
 172            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 173            .unwrap()
 174    });
 175    assert_eq!(
 176        messages(&context, cx),
 177        vec![
 178            (message_1.id, Role::User, 0..3),
 179            (message_5.id, Role::System, 3..4),
 180            (message_3.id, Role::User, 4..5)
 181        ]
 182    );
 183}
 184
 185#[gpui::test]
 186fn test_message_splitting(cx: &mut App) {
 187    let settings_store = SettingsStore::test(cx);
 188    cx.set_global(settings_store);
 189    LanguageModelRegistry::test(cx);
 190    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 191
 192    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 193    let context = cx.new(|cx| {
 194        AssistantContext::local(
 195            registry.clone(),
 196            None,
 197            None,
 198            prompt_builder.clone(),
 199            Arc::new(SlashCommandWorkingSet::default()),
 200            Arc::new(ToolWorkingSet::default()),
 201            cx,
 202        )
 203    });
 204    let buffer = context.read(cx).buffer.clone();
 205
 206    let message_1 = context.read(cx).message_anchors[0].clone();
 207    assert_eq!(
 208        messages(&context, cx),
 209        vec![(message_1.id, Role::User, 0..0)]
 210    );
 211
 212    buffer.update(cx, |buffer, cx| {
 213        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 214    });
 215
 216    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 217    let message_2 = message_2.unwrap();
 218
 219    // We recycle newlines in the middle of a split message
 220    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 221    assert_eq!(
 222        messages(&context, cx),
 223        vec![
 224            (message_1.id, Role::User, 0..4),
 225            (message_2.id, Role::User, 4..16),
 226        ]
 227    );
 228
 229    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 230    let message_3 = message_3.unwrap();
 231
 232    // We don't recycle newlines at the end of a split message
 233    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 234    assert_eq!(
 235        messages(&context, cx),
 236        vec![
 237            (message_1.id, Role::User, 0..4),
 238            (message_3.id, Role::User, 4..5),
 239            (message_2.id, Role::User, 5..17),
 240        ]
 241    );
 242
 243    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 244    let message_4 = message_4.unwrap();
 245    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 246    assert_eq!(
 247        messages(&context, cx),
 248        vec![
 249            (message_1.id, Role::User, 0..4),
 250            (message_3.id, Role::User, 4..5),
 251            (message_2.id, Role::User, 5..9),
 252            (message_4.id, Role::User, 9..17),
 253        ]
 254    );
 255
 256    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 257    let message_5 = message_5.unwrap();
 258    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 259    assert_eq!(
 260        messages(&context, cx),
 261        vec![
 262            (message_1.id, Role::User, 0..4),
 263            (message_3.id, Role::User, 4..5),
 264            (message_2.id, Role::User, 5..9),
 265            (message_4.id, Role::User, 9..10),
 266            (message_5.id, Role::User, 10..18),
 267        ]
 268    );
 269
 270    let (message_6, message_7) =
 271        context.update(cx, |context, cx| context.split_message(14..16, cx));
 272    let message_6 = message_6.unwrap();
 273    let message_7 = message_7.unwrap();
 274    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 275    assert_eq!(
 276        messages(&context, cx),
 277        vec![
 278            (message_1.id, Role::User, 0..4),
 279            (message_3.id, Role::User, 4..5),
 280            (message_2.id, Role::User, 5..9),
 281            (message_4.id, Role::User, 9..10),
 282            (message_5.id, Role::User, 10..14),
 283            (message_6.id, Role::User, 14..17),
 284            (message_7.id, Role::User, 17..19),
 285        ]
 286    );
 287}
 288
 289#[gpui::test]
 290fn test_messages_for_offsets(cx: &mut App) {
 291    let settings_store = SettingsStore::test(cx);
 292    LanguageModelRegistry::test(cx);
 293    cx.set_global(settings_store);
 294    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 295    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 296    let context = cx.new(|cx| {
 297        AssistantContext::local(
 298            registry,
 299            None,
 300            None,
 301            prompt_builder.clone(),
 302            Arc::new(SlashCommandWorkingSet::default()),
 303            Arc::new(ToolWorkingSet::default()),
 304            cx,
 305        )
 306    });
 307    let buffer = context.read(cx).buffer.clone();
 308
 309    let message_1 = context.read(cx).message_anchors[0].clone();
 310    assert_eq!(
 311        messages(&context, cx),
 312        vec![(message_1.id, Role::User, 0..0)]
 313    );
 314
 315    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 316    let message_2 = context
 317        .update(cx, |context, cx| {
 318            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 319        })
 320        .unwrap();
 321    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 322
 323    let message_3 = context
 324        .update(cx, |context, cx| {
 325            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 326        })
 327        .unwrap();
 328    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 329
 330    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 331    assert_eq!(
 332        messages(&context, cx),
 333        vec![
 334            (message_1.id, Role::User, 0..4),
 335            (message_2.id, Role::User, 4..8),
 336            (message_3.id, Role::User, 8..11)
 337        ]
 338    );
 339
 340    assert_eq!(
 341        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 342        [message_1.id, message_2.id, message_3.id]
 343    );
 344    assert_eq!(
 345        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 346        [message_1.id, message_3.id]
 347    );
 348
 349    let message_4 = context
 350        .update(cx, |context, cx| {
 351            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 352        })
 353        .unwrap();
 354    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 355    assert_eq!(
 356        messages(&context, cx),
 357        vec![
 358            (message_1.id, Role::User, 0..4),
 359            (message_2.id, Role::User, 4..8),
 360            (message_3.id, Role::User, 8..12),
 361            (message_4.id, Role::User, 12..12)
 362        ]
 363    );
 364    assert_eq!(
 365        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 366        [message_1.id, message_2.id, message_3.id, message_4.id]
 367    );
 368
 369    fn message_ids_for_offsets(
 370        context: &Entity<AssistantContext>,
 371        offsets: &[usize],
 372        cx: &App,
 373    ) -> Vec<MessageId> {
 374        context
 375            .read(cx)
 376            .messages_for_offsets(offsets.iter().copied(), cx)
 377            .into_iter()
 378            .map(|message| message.id)
 379            .collect()
 380    }
 381}
 382
 383#[gpui::test]
 384async fn test_slash_commands(cx: &mut TestAppContext) {
 385    let settings_store = cx.update(SettingsStore::test);
 386    cx.set_global(settings_store);
 387    cx.update(LanguageModelRegistry::test);
 388    cx.update(Project::init_settings);
 389    let fs = FakeFs::new(cx.background_executor.clone());
 390
 391    fs.insert_tree(
 392        "/test",
 393        json!({
 394            "src": {
 395                "lib.rs": "fn one() -> usize { 1 }",
 396                "main.rs": "
 397                    use crate::one;
 398                    fn main() { one(); }
 399                ".unindent(),
 400            }
 401        }),
 402    )
 403    .await;
 404
 405    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 406    slash_command_registry.register_command(FileSlashCommand, false);
 407
 408    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 409    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 410    let context = cx.new(|cx| {
 411        AssistantContext::local(
 412            registry.clone(),
 413            None,
 414            None,
 415            prompt_builder.clone(),
 416            Arc::new(SlashCommandWorkingSet::default()),
 417            Arc::new(ToolWorkingSet::default()),
 418            cx,
 419        )
 420    });
 421
 422    #[derive(Default)]
 423    struct ContextRanges {
 424        parsed_commands: HashSet<Range<language::Anchor>>,
 425        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 426        output_sections: HashSet<Range<language::Anchor>>,
 427    }
 428
 429    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 430    context.update(cx, |_, cx| {
 431        cx.subscribe(&context, {
 432            let context_ranges = context_ranges.clone();
 433            move |context, _, event, _| {
 434                let mut context_ranges = context_ranges.borrow_mut();
 435                match event {
 436                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 437                        let command = context.invoked_slash_command(command_id).unwrap();
 438                        context_ranges
 439                            .command_outputs
 440                            .insert(*command_id, command.range.clone());
 441                    }
 442                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 443                        for range in removed {
 444                            context_ranges.parsed_commands.remove(range);
 445                        }
 446                        for command in updated {
 447                            context_ranges
 448                                .parsed_commands
 449                                .insert(command.source_range.clone());
 450                        }
 451                    }
 452                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 453                        context_ranges.output_sections.insert(section.range.clone());
 454                    }
 455                    _ => {}
 456                }
 457            }
 458        })
 459        .detach();
 460    });
 461
 462    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 463
 464    // Insert a slash command
 465    buffer.update(cx, |buffer, cx| {
 466        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 467    });
 468    assert_text_and_context_ranges(
 469        &buffer,
 470        &context_ranges,
 471        &"
 472        «/file src/lib.rs»"
 473            .unindent(),
 474        cx,
 475    );
 476
 477    // Edit the argument of the slash command.
 478    buffer.update(cx, |buffer, cx| {
 479        let edit_offset = buffer.text().find("lib.rs").unwrap();
 480        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 481    });
 482    assert_text_and_context_ranges(
 483        &buffer,
 484        &context_ranges,
 485        &"
 486        «/file src/main.rs»"
 487            .unindent(),
 488        cx,
 489    );
 490
 491    // Edit the name of the slash command, using one that doesn't exist.
 492    buffer.update(cx, |buffer, cx| {
 493        let edit_offset = buffer.text().find("/file").unwrap();
 494        buffer.edit(
 495            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 496            None,
 497            cx,
 498        );
 499    });
 500    assert_text_and_context_ranges(
 501        &buffer,
 502        &context_ranges,
 503        &"
 504        /unknown src/main.rs"
 505            .unindent(),
 506        cx,
 507    );
 508
 509    // Undoing the insertion of an non-existent slash command resorts the previous one.
 510    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 511    assert_text_and_context_ranges(
 512        &buffer,
 513        &context_ranges,
 514        &"
 515        «/file src/main.rs»"
 516            .unindent(),
 517        cx,
 518    );
 519
 520    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 521    context.update(cx, |context, cx| {
 522        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 523        context.insert_command_output(
 524            command_source_range,
 525            "file",
 526            Task::ready(Ok(command_output_rx.boxed())),
 527            true,
 528            cx,
 529        );
 530    });
 531    assert_text_and_context_ranges(
 532        &buffer,
 533        &context_ranges,
 534        &"
 535        ⟦«/file src/main.rs»
 536        …⟧
 537        "
 538        .unindent(),
 539        cx,
 540    );
 541
 542    command_output_tx
 543        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 544            icon: IconName::Ai,
 545            label: "src/main.rs".into(),
 546            metadata: None,
 547        }))
 548        .unwrap();
 549    command_output_tx
 550        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 551        .unwrap();
 552    cx.run_until_parked();
 553    assert_text_and_context_ranges(
 554        &buffer,
 555        &context_ranges,
 556        &"
 557        ⟦«/file src/main.rs»
 558        src/main.rs…⟧
 559        "
 560        .unindent(),
 561        cx,
 562    );
 563
 564    command_output_tx
 565        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 566        .unwrap();
 567    cx.run_until_parked();
 568    assert_text_and_context_ranges(
 569        &buffer,
 570        &context_ranges,
 571        &"
 572        ⟦«/file src/main.rs»
 573        src/main.rs
 574        fn main() {}…⟧
 575        "
 576        .unindent(),
 577        cx,
 578    );
 579
 580    command_output_tx
 581        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 582        .unwrap();
 583    cx.run_until_parked();
 584    assert_text_and_context_ranges(
 585        &buffer,
 586        &context_ranges,
 587        &"
 588        ⟦«/file src/main.rs»
 589        ⟪src/main.rs
 590        fn main() {}⟫…⟧
 591        "
 592        .unindent(),
 593        cx,
 594    );
 595
 596    drop(command_output_tx);
 597    cx.run_until_parked();
 598    assert_text_and_context_ranges(
 599        &buffer,
 600        &context_ranges,
 601        &"
 602        ⟦⟪src/main.rs
 603        fn main() {}⟫⟧
 604        "
 605        .unindent(),
 606        cx,
 607    );
 608
 609    #[track_caller]
 610    fn assert_text_and_context_ranges(
 611        buffer: &Entity<Buffer>,
 612        ranges: &RefCell<ContextRanges>,
 613        expected_marked_text: &str,
 614        cx: &mut TestAppContext,
 615    ) {
 616        let mut actual_marked_text = String::new();
 617        buffer.update(cx, |buffer, _| {
 618            struct Endpoint {
 619                offset: usize,
 620                marker: char,
 621            }
 622
 623            let ranges = ranges.borrow();
 624            let mut endpoints = Vec::new();
 625            for range in ranges.command_outputs.values() {
 626                endpoints.push(Endpoint {
 627                    offset: range.start.to_offset(buffer),
 628                    marker: '',
 629                });
 630            }
 631            for range in ranges.parsed_commands.iter() {
 632                endpoints.push(Endpoint {
 633                    offset: range.start.to_offset(buffer),
 634                    marker: '«',
 635                });
 636            }
 637            for range in ranges.output_sections.iter() {
 638                endpoints.push(Endpoint {
 639                    offset: range.start.to_offset(buffer),
 640                    marker: '',
 641                });
 642            }
 643
 644            for range in ranges.output_sections.iter() {
 645                endpoints.push(Endpoint {
 646                    offset: range.end.to_offset(buffer),
 647                    marker: '',
 648                });
 649            }
 650            for range in ranges.parsed_commands.iter() {
 651                endpoints.push(Endpoint {
 652                    offset: range.end.to_offset(buffer),
 653                    marker: '»',
 654                });
 655            }
 656            for range in ranges.command_outputs.values() {
 657                endpoints.push(Endpoint {
 658                    offset: range.end.to_offset(buffer),
 659                    marker: '',
 660                });
 661            }
 662
 663            endpoints.sort_by_key(|endpoint| endpoint.offset);
 664            let mut offset = 0;
 665            for endpoint in endpoints {
 666                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 667                actual_marked_text.push(endpoint.marker);
 668                offset = endpoint.offset;
 669            }
 670            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 671        });
 672
 673        assert_eq!(actual_marked_text, expected_marked_text);
 674    }
 675}
 676
 677#[gpui::test]
 678async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 679    cx.update(prompt_library::init);
 680    let mut settings_store = cx.update(SettingsStore::test);
 681    cx.update(|cx| {
 682        settings_store
 683            .set_user_settings(
 684                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 685                cx,
 686            )
 687            .unwrap()
 688    });
 689    cx.set_global(settings_store);
 690    cx.update(language::init);
 691    cx.update(Project::init_settings);
 692    let fs = FakeFs::new(cx.executor());
 693    let project = Project::test(fs, [Path::new("/root")], cx).await;
 694    cx.update(LanguageModelRegistry::test);
 695
 696    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 697
 698    // Create a new context
 699    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 700    let context = cx.new(|cx| {
 701        AssistantContext::local(
 702            registry.clone(),
 703            Some(project),
 704            None,
 705            prompt_builder.clone(),
 706            Arc::new(SlashCommandWorkingSet::default()),
 707            Arc::new(ToolWorkingSet::default()),
 708            cx,
 709        )
 710    });
 711
 712    // Insert an assistant message to simulate a response.
 713    let assistant_message_id = context.update(cx, |context, cx| {
 714        let user_message_id = context.messages(cx).next().unwrap().id;
 715        context
 716            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 717            .unwrap()
 718            .id
 719    });
 720
 721    // No edit tags
 722    edit(
 723        &context,
 724        "
 725
 726        «one
 727        two
 728        »",
 729        cx,
 730    );
 731    expect_patches(
 732        &context,
 733        "
 734
 735        one
 736        two
 737        ",
 738        &[],
 739        cx,
 740    );
 741
 742    // Partial edit step tag is added
 743    edit(
 744        &context,
 745        "
 746
 747        one
 748        two
 749        «
 750        <patch»",
 751        cx,
 752    );
 753    expect_patches(
 754        &context,
 755        "
 756
 757        one
 758        two
 759
 760        <patch",
 761        &[],
 762        cx,
 763    );
 764
 765    // The rest of the step tag is added. The unclosed
 766    // step is treated as incomplete.
 767    edit(
 768        &context,
 769        "
 770
 771        one
 772        two
 773
 774        <patch«>
 775        <edit>»",
 776        cx,
 777    );
 778    expect_patches(
 779        &context,
 780        "
 781
 782        one
 783        two
 784
 785        «<patch>
 786        <edit>»",
 787        &[&[]],
 788        cx,
 789    );
 790
 791    // The full patch is added
 792    edit(
 793        &context,
 794        "
 795
 796        one
 797        two
 798
 799        <patch>
 800        <edit>«
 801        <description>add a `two` function</description>
 802        <path>src/lib.rs</path>
 803        <operation>insert_after</operation>
 804        <old_text>fn one</old_text>
 805        <new_text>
 806        fn two() {}
 807        </new_text>
 808        </edit>
 809        </patch>
 810
 811        also,»",
 812        cx,
 813    );
 814    expect_patches(
 815        &context,
 816        "
 817
 818        one
 819        two
 820
 821        «<patch>
 822        <edit>
 823        <description>add a `two` function</description>
 824        <path>src/lib.rs</path>
 825        <operation>insert_after</operation>
 826        <old_text>fn one</old_text>
 827        <new_text>
 828        fn two() {}
 829        </new_text>
 830        </edit>
 831        </patch>
 832        »
 833        also,",
 834        &[&[AssistantEdit {
 835            path: "src/lib.rs".into(),
 836            kind: AssistantEditKind::InsertAfter {
 837                old_text: "fn one".into(),
 838                new_text: "fn two() {}".into(),
 839                description: Some("add a `two` function".into()),
 840            },
 841        }]],
 842        cx,
 843    );
 844
 845    // The step is manually edited.
 846    edit(
 847        &context,
 848        "
 849
 850        one
 851        two
 852
 853        <patch>
 854        <edit>
 855        <description>add a `two` function</description>
 856        <path>src/lib.rs</path>
 857        <operation>insert_after</operation>
 858        <old_text>«fn zero»</old_text>
 859        <new_text>
 860        fn two() {}
 861        </new_text>
 862        </edit>
 863        </patch>
 864
 865        also,",
 866        cx,
 867    );
 868    expect_patches(
 869        &context,
 870        "
 871
 872        one
 873        two
 874
 875        «<patch>
 876        <edit>
 877        <description>add a `two` function</description>
 878        <path>src/lib.rs</path>
 879        <operation>insert_after</operation>
 880        <old_text>fn zero</old_text>
 881        <new_text>
 882        fn two() {}
 883        </new_text>
 884        </edit>
 885        </patch>
 886        »
 887        also,",
 888        &[&[AssistantEdit {
 889            path: "src/lib.rs".into(),
 890            kind: AssistantEditKind::InsertAfter {
 891                old_text: "fn zero".into(),
 892                new_text: "fn two() {}".into(),
 893                description: Some("add a `two` function".into()),
 894            },
 895        }]],
 896        cx,
 897    );
 898
 899    // When setting the message role to User, the steps are cleared.
 900    context.update(cx, |context, cx| {
 901        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 902        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 903    });
 904    expect_patches(
 905        &context,
 906        "
 907
 908        one
 909        two
 910
 911        <patch>
 912        <edit>
 913        <description>add a `two` function</description>
 914        <path>src/lib.rs</path>
 915        <operation>insert_after</operation>
 916        <old_text>fn zero</old_text>
 917        <new_text>
 918        fn two() {}
 919        </new_text>
 920        </edit>
 921        </patch>
 922
 923        also,",
 924        &[],
 925        cx,
 926    );
 927
 928    // When setting the message role back to Assistant, the steps are reparsed.
 929    context.update(cx, |context, cx| {
 930        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 931    });
 932    expect_patches(
 933        &context,
 934        "
 935
 936        one
 937        two
 938
 939        «<patch>
 940        <edit>
 941        <description>add a `two` function</description>
 942        <path>src/lib.rs</path>
 943        <operation>insert_after</operation>
 944        <old_text>fn zero</old_text>
 945        <new_text>
 946        fn two() {}
 947        </new_text>
 948        </edit>
 949        </patch>
 950        »
 951        also,",
 952        &[&[AssistantEdit {
 953            path: "src/lib.rs".into(),
 954            kind: AssistantEditKind::InsertAfter {
 955                old_text: "fn zero".into(),
 956                new_text: "fn two() {}".into(),
 957                description: Some("add a `two` function".into()),
 958            },
 959        }]],
 960        cx,
 961    );
 962
 963    // Ensure steps are re-parsed when deserializing.
 964    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 965    let deserialized_context = cx.new(|cx| {
 966        AssistantContext::deserialize(
 967            serialized_context,
 968            Default::default(),
 969            registry.clone(),
 970            prompt_builder.clone(),
 971            Arc::new(SlashCommandWorkingSet::default()),
 972            Arc::new(ToolWorkingSet::default()),
 973            None,
 974            None,
 975            cx,
 976        )
 977    });
 978    expect_patches(
 979        &deserialized_context,
 980        "
 981
 982        one
 983        two
 984
 985        «<patch>
 986        <edit>
 987        <description>add a `two` function</description>
 988        <path>src/lib.rs</path>
 989        <operation>insert_after</operation>
 990        <old_text>fn zero</old_text>
 991        <new_text>
 992        fn two() {}
 993        </new_text>
 994        </edit>
 995        </patch>
 996        »
 997        also,",
 998        &[&[AssistantEdit {
 999            path: "src/lib.rs".into(),
1000            kind: AssistantEditKind::InsertAfter {
1001                old_text: "fn zero".into(),
1002                new_text: "fn two() {}".into(),
1003                description: Some("add a `two` function".into()),
1004            },
1005        }]],
1006        cx,
1007    );
1008
1009    fn edit(
1010        context: &Entity<AssistantContext>,
1011        new_text_marked_with_edits: &str,
1012        cx: &mut TestAppContext,
1013    ) {
1014        context.update(cx, |context, cx| {
1015            context.buffer.update(cx, |buffer, cx| {
1016                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1017            });
1018        });
1019        cx.executor().run_until_parked();
1020    }
1021
1022    #[track_caller]
1023    fn expect_patches(
1024        context: &Entity<AssistantContext>,
1025        expected_marked_text: &str,
1026        expected_suggestions: &[&[AssistantEdit]],
1027        cx: &mut TestAppContext,
1028    ) {
1029        let expected_marked_text = expected_marked_text.unindent();
1030        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1031
1032        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1033            context.buffer.read_with(cx, |buffer, _| {
1034                let ranges = context
1035                    .patches
1036                    .iter()
1037                    .map(|entry| entry.range.to_offset(buffer))
1038                    .collect::<Vec<_>>();
1039                (
1040                    buffer.text(),
1041                    ranges,
1042                    context
1043                        .patches
1044                        .iter()
1045                        .map(|step| step.edits.clone())
1046                        .collect::<Vec<_>>(),
1047                )
1048            })
1049        });
1050
1051        assert_eq!(buffer_text, expected_text);
1052
1053        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1054        assert_eq!(actual_marked_text, expected_marked_text);
1055
1056        assert_eq!(
1057            patches
1058                .iter()
1059                .map(|patch| {
1060                    patch
1061                        .iter()
1062                        .map(|edit| {
1063                            let edit = edit.as_ref().unwrap();
1064                            AssistantEdit {
1065                                path: edit.path.clone(),
1066                                kind: edit.kind.clone(),
1067                            }
1068                        })
1069                        .collect::<Vec<_>>()
1070                })
1071                .collect::<Vec<_>>(),
1072            expected_suggestions
1073        );
1074    }
1075}
1076
1077#[gpui::test]
1078async fn test_serialization(cx: &mut TestAppContext) {
1079    let settings_store = cx.update(SettingsStore::test);
1080    cx.set_global(settings_store);
1081    cx.update(LanguageModelRegistry::test);
1082    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1083    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1084    let context = cx.new(|cx| {
1085        AssistantContext::local(
1086            registry.clone(),
1087            None,
1088            None,
1089            prompt_builder.clone(),
1090            Arc::new(SlashCommandWorkingSet::default()),
1091            Arc::new(ToolWorkingSet::default()),
1092            cx,
1093        )
1094    });
1095    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1096    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1097    let message_1 = context.update(cx, |context, cx| {
1098        context
1099            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1100            .unwrap()
1101    });
1102    let message_2 = context.update(cx, |context, cx| {
1103        context
1104            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1105            .unwrap()
1106    });
1107    buffer.update(cx, |buffer, cx| {
1108        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1109        buffer.finalize_last_transaction();
1110    });
1111    let _message_3 = context.update(cx, |context, cx| {
1112        context
1113            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1114            .unwrap()
1115    });
1116    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1117    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1118    assert_eq!(
1119        cx.read(|cx| messages(&context, cx)),
1120        [
1121            (message_0, Role::User, 0..2),
1122            (message_1.id, Role::Assistant, 2..6),
1123            (message_2.id, Role::System, 6..6),
1124        ]
1125    );
1126
1127    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1128    let deserialized_context = cx.new(|cx| {
1129        AssistantContext::deserialize(
1130            serialized_context,
1131            Default::default(),
1132            registry.clone(),
1133            prompt_builder.clone(),
1134            Arc::new(SlashCommandWorkingSet::default()),
1135            Arc::new(ToolWorkingSet::default()),
1136            None,
1137            None,
1138            cx,
1139        )
1140    });
1141    let deserialized_buffer =
1142        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1143    assert_eq!(
1144        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1145        "a\nb\nc\n"
1146    );
1147    assert_eq!(
1148        cx.read(|cx| messages(&deserialized_context, cx)),
1149        [
1150            (message_0, Role::User, 0..2),
1151            (message_1.id, Role::Assistant, 2..6),
1152            (message_2.id, Role::System, 6..6),
1153        ]
1154    );
1155}
1156
// Randomized convergence test for collaborative contexts.
//
// Several replicas of one context (all sharing `context_id`) exchange operations through a
// simulated `Network`. Meanwhile the test randomly applies:
// - buffer edits;
// - message splits and inserts;
// - slash-command output insertions;
// - message status updates;
// - peer disconnects and reconnects.
//
// Afterwards every replica must have no pending ops and must match replica 0 in buffer text,
// message anchors, message metadata and slash command output sections.
//
// The `MIN_PEERS`, `MAX_PEERS` and `OPERATIONS` environment variables override the defaults
// (2, 5 and 50).
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    // Register a few fake commands so the slash-command branch below has names to choose from.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    // One context per peer, using the loop index as its replica id. Every operation a context
    // emits is serialized with `to_proto` and broadcast on the network from that replica.
    for i in 0..num_peers {
        let context = cx.new(|cx| {
            AssistantContext::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                Arc::new(SlashCommandWorkingSet::default()),
                Arc::new(ToolWorkingSet::default()),
                None,
                None,
                cx,
            )
        });

        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    // Remaining budget of mutating steps (branches 0..=84 below each consume one).
    let mut mutation_count = operations;

    // Keep stepping until the budget is spent, the network reports idle, and no peer remains
    // disconnected.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // Branch weights out of 100. Once the budget is exhausted, the guarded arms are skipped
        // and every step falls through to the network-handling arm.
        match rng.gen_range(0..100) {
            // 30%: random buffer edit.
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            // 15%: split a message at a random byte range.
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            // 15%: insert a message with a random role after a randomly chosen message.
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            // 15%: insert a random registered command on its own line at a random offset,
            // then feed it synthetic slash-command output.
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // `command_range` covers just the command text, skipping the leading "\n".
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(10)
                        .collect::<String>();

                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
                        role: Role::User,
                        merge_same_roles: true,
                    })];

                    // Carve `output_text` into 0-3 consecutive labeled sections.
                    let num_sections = rng.gen_range(0..=3);
                    let mut section_start = 0;
                    for _ in 0..num_sections {
                        let mut section_end = rng.gen_range(section_start..=output_text.len());
                        // Advance to the next char boundary so the slice below cannot split a
                        // multi-byte character.
                        while !output_text.is_char_boundary(section_end) {
                            section_end += 1;
                        }
                        events.push(Ok(SlashCommandEvent::StartSection {
                            icon: IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        }));
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..section_end].to_string(),
                            run_commands_in_text: false,
                        })));
                        events.push(Ok(SlashCommandEvent::EndSection));
                        section_start = section_end;
                    }

                    // Any text left after the last section is emitted as unsectioned content.
                    if section_start < output_text.len() {
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..].to_string(),
                            run_commands_in_text: false,
                        })));
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?} events",
                        context_index,
                        command_range,
                        events.len()
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        "/command",
                        Task::ready(Ok(stream::iter(events).boxed())),
                        true,
                        cx,
                    );
                });
                cx.run_until_parked();
                mutation_count -= 1;
            }
            // 10%: set a random message's status to Done, Pending or Error.
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            // Network handling:
            // - a disconnected replica reconnects and catches up with replica 0;
            // - otherwise, a non-zero replica may be disconnected (10% chance);
            // - otherwise, it applies any operations it has not yet received.
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    // Exchange the ops each side is missing, computed against the other's
                    // version vector.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Replica 0 is never disconnected; it is the reference peer.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Convergence check: every replica must agree with replica 0.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1439
1440#[gpui::test]
1441fn test_mark_cache_anchors(cx: &mut App) {
1442    let settings_store = SettingsStore::test(cx);
1443    LanguageModelRegistry::test(cx);
1444    cx.set_global(settings_store);
1445    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1446    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1447    let context = cx.new(|cx| {
1448        AssistantContext::local(
1449            registry,
1450            None,
1451            None,
1452            prompt_builder.clone(),
1453            Arc::new(SlashCommandWorkingSet::default()),
1454            Arc::new(ToolWorkingSet::default()),
1455            cx,
1456        )
1457    });
1458    let buffer = context.read(cx).buffer.clone();
1459
1460    // Create a test cache configuration
1461    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1462        max_cache_anchors: 3,
1463        should_speculate: true,
1464        min_total_token: 10,
1465    });
1466
1467    let message_1 = context.read(cx).message_anchors[0].clone();
1468
1469    context.update(cx, |context, cx| {
1470        context.mark_cache_anchors(cache_configuration, false, cx)
1471    });
1472
1473    assert_eq!(
1474        messages_cache(&context, cx)
1475            .iter()
1476            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1477            .count(),
1478        0,
1479        "Empty messages should not have any cache anchors."
1480    );
1481
1482    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1483    let message_2 = context
1484        .update(cx, |context, cx| {
1485            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1486        })
1487        .unwrap();
1488
1489    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1490    let message_3 = context
1491        .update(cx, |context, cx| {
1492            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1493        })
1494        .unwrap();
1495    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1496
1497    context.update(cx, |context, cx| {
1498        context.mark_cache_anchors(cache_configuration, false, cx)
1499    });
1500    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1501    assert_eq!(
1502        messages_cache(&context, cx)
1503            .iter()
1504            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1505            .count(),
1506        0,
1507        "Messages should not be marked for cache before going over the token minimum."
1508    );
1509    context.update(cx, |context, _| {
1510        context.token_count = Some(20);
1511    });
1512
1513    context.update(cx, |context, cx| {
1514        context.mark_cache_anchors(cache_configuration, true, cx)
1515    });
1516    assert_eq!(
1517        messages_cache(&context, cx)
1518            .iter()
1519            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1520            .collect::<Vec<bool>>(),
1521        vec![true, true, false],
1522        "Last message should not be an anchor on speculative request."
1523    );
1524
1525    context
1526        .update(cx, |context, cx| {
1527            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1528        })
1529        .unwrap();
1530
1531    context.update(cx, |context, cx| {
1532        context.mark_cache_anchors(cache_configuration, false, cx)
1533    });
1534    assert_eq!(
1535        messages_cache(&context, cx)
1536            .iter()
1537            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1538            .collect::<Vec<bool>>(),
1539        vec![false, true, true, false],
1540        "Most recent message should also be cached if not a speculative request."
1541    );
1542    context.update(cx, |context, cx| {
1543        context.update_cache_status_for_completion(cx)
1544    });
1545    assert_eq!(
1546        messages_cache(&context, cx)
1547            .iter()
1548            .map(|(_, cache)| cache
1549                .as_ref()
1550                .map_or(None, |cache| Some(cache.status.clone())))
1551            .collect::<Vec<Option<CacheStatus>>>(),
1552        vec![
1553            Some(CacheStatus::Cached),
1554            Some(CacheStatus::Cached),
1555            Some(CacheStatus::Cached),
1556            None
1557        ],
1558        "All user messages prior to anchor should be marked as cached."
1559    );
1560
1561    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1562    context.update(cx, |context, cx| {
1563        context.mark_cache_anchors(cache_configuration, false, cx)
1564    });
1565    assert_eq!(
1566        messages_cache(&context, cx)
1567            .iter()
1568            .map(|(_, cache)| cache
1569                .as_ref()
1570                .map_or(None, |cache| Some(cache.status.clone())))
1571            .collect::<Vec<Option<CacheStatus>>>(),
1572        vec![
1573            Some(CacheStatus::Cached),
1574            Some(CacheStatus::Cached),
1575            Some(CacheStatus::Pending),
1576            None
1577        ],
1578        "Modifying a message should invalidate it's cache but leave previous messages."
1579    );
1580    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1581    context.update(cx, |context, cx| {
1582        context.mark_cache_anchors(cache_configuration, false, cx)
1583    });
1584    assert_eq!(
1585        messages_cache(&context, cx)
1586            .iter()
1587            .map(|(_, cache)| cache
1588                .as_ref()
1589                .map_or(None, |cache| Some(cache.status.clone())))
1590            .collect::<Vec<Option<CacheStatus>>>(),
1591        vec![
1592            Some(CacheStatus::Pending),
1593            Some(CacheStatus::Pending),
1594            Some(CacheStatus::Pending),
1595            None
1596        ],
1597        "Modifying a message should invalidate all future messages."
1598    );
1599}
1600
1601fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1602    context
1603        .read(cx)
1604        .messages(cx)
1605        .map(|message| (message.id, message.role, message.offset_range))
1606        .collect()
1607}
1608
1609fn messages_cache(
1610    context: &Entity<AssistantContext>,
1611    cx: &App,
1612) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1613    context
1614        .read(cx)
1615        .messages(cx)
1616        .map(|message| (message.id, message.cache.clone()))
1617        .collect()
1618}
1619
1620#[derive(Clone)]
1621struct FakeSlashCommand(String);
1622
1623impl SlashCommand for FakeSlashCommand {
1624    fn name(&self) -> String {
1625        self.0.clone()
1626    }
1627
1628    fn description(&self) -> String {
1629        format!("Fake slash command: {}", self.0)
1630    }
1631
1632    fn menu_text(&self) -> String {
1633        format!("Run fake command: {}", self.0)
1634    }
1635
1636    fn complete_argument(
1637        self: Arc<Self>,
1638        _arguments: &[String],
1639        _cancel: Arc<AtomicBool>,
1640        _workspace: Option<WeakEntity<Workspace>>,
1641        _window: &mut Window,
1642        _cx: &mut App,
1643    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1644        Task::ready(Ok(vec![]))
1645    }
1646
1647    fn requires_argument(&self) -> bool {
1648        false
1649    }
1650
1651    fn run(
1652        self: Arc<Self>,
1653        _arguments: &[String],
1654        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1655        _context_buffer: BufferSnapshot,
1656        _workspace: WeakEntity<Workspace>,
1657        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1658        _window: &mut Window,
1659        _cx: &mut App,
1660    ) -> Task<SlashCommandResult> {
1661        Task::ready(Ok(SlashCommandOutput {
1662            text: format!("Executed fake command: {}", self.0),
1663            sections: vec![],
1664            run_commands_in_text: false,
1665        }
1666        .to_event_stream()))
1667    }
1668}