context_tests.rs

   1use super::{AssistantEdit, MessageCacheMetadata};
   2use crate::{
   3    assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
   4    Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
   5    SlashCommandId,
   6};
   7use anyhow::Result;
   8use assistant_slash_command::{
   9    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
  10    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult,
  11};
  12use collections::{HashMap, HashSet};
  13use fs::FakeFs;
  14use futures::{
  15    channel::mpsc,
  16    stream::{self, StreamExt},
  17};
  18use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  19use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  20use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  21use parking_lot::Mutex;
  22use pretty_assertions::assert_eq;
  23use project::Project;
  24use rand::prelude::*;
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{
  28    cell::RefCell,
  29    env,
  30    ops::Range,
  31    path::Path,
  32    rc::Rc,
  33    sync::{atomic::AtomicBool, Arc},
  34};
  35use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
  36use ui::{Context as _, IconName, WindowContext};
  37use unindent::Unindent;
  38use util::{
  39    test::{generate_marked_text, marked_text_ranges},
  40    RandomCharIter,
  41};
  42use workspace::Workspace;
  43
  44#[gpui::test]
  45fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  46    let settings_store = SettingsStore::test(cx);
  47    LanguageModelRegistry::test(cx);
  48    cx.set_global(settings_store);
  49    assistant_panel::init(cx);
  50    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  51    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  52    let context =
  53        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
  54    let buffer = context.read(cx).buffer.clone();
  55
  56    let message_1 = context.read(cx).message_anchors[0].clone();
  57    assert_eq!(
  58        messages(&context, cx),
  59        vec![(message_1.id, Role::User, 0..0)]
  60    );
  61
  62    let message_2 = context.update(cx, |context, cx| {
  63        context
  64            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  65            .unwrap()
  66    });
  67    assert_eq!(
  68        messages(&context, cx),
  69        vec![
  70            (message_1.id, Role::User, 0..1),
  71            (message_2.id, Role::Assistant, 1..1)
  72        ]
  73    );
  74
  75    buffer.update(cx, |buffer, cx| {
  76        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  77    });
  78    assert_eq!(
  79        messages(&context, cx),
  80        vec![
  81            (message_1.id, Role::User, 0..2),
  82            (message_2.id, Role::Assistant, 2..3)
  83        ]
  84    );
  85
  86    let message_3 = context.update(cx, |context, cx| {
  87        context
  88            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  89            .unwrap()
  90    });
  91    assert_eq!(
  92        messages(&context, cx),
  93        vec![
  94            (message_1.id, Role::User, 0..2),
  95            (message_2.id, Role::Assistant, 2..4),
  96            (message_3.id, Role::User, 4..4)
  97        ]
  98    );
  99
 100    let message_4 = context.update(cx, |context, cx| {
 101        context
 102            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 103            .unwrap()
 104    });
 105    assert_eq!(
 106        messages(&context, cx),
 107        vec![
 108            (message_1.id, Role::User, 0..2),
 109            (message_2.id, Role::Assistant, 2..4),
 110            (message_4.id, Role::User, 4..5),
 111            (message_3.id, Role::User, 5..5),
 112        ]
 113    );
 114
 115    buffer.update(cx, |buffer, cx| {
 116        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 117    });
 118    assert_eq!(
 119        messages(&context, cx),
 120        vec![
 121            (message_1.id, Role::User, 0..2),
 122            (message_2.id, Role::Assistant, 2..4),
 123            (message_4.id, Role::User, 4..6),
 124            (message_3.id, Role::User, 6..7),
 125        ]
 126    );
 127
 128    // Deleting across message boundaries merges the messages.
 129    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 130    assert_eq!(
 131        messages(&context, cx),
 132        vec![
 133            (message_1.id, Role::User, 0..3),
 134            (message_3.id, Role::User, 3..4),
 135        ]
 136    );
 137
 138    // Undoing the deletion should also undo the merge.
 139    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 140    assert_eq!(
 141        messages(&context, cx),
 142        vec![
 143            (message_1.id, Role::User, 0..2),
 144            (message_2.id, Role::Assistant, 2..4),
 145            (message_4.id, Role::User, 4..6),
 146            (message_3.id, Role::User, 6..7),
 147        ]
 148    );
 149
 150    // Redoing the deletion should also redo the merge.
 151    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 152    assert_eq!(
 153        messages(&context, cx),
 154        vec![
 155            (message_1.id, Role::User, 0..3),
 156            (message_3.id, Role::User, 3..4),
 157        ]
 158    );
 159
 160    // Ensure we can still insert after a merged message.
 161    let message_5 = context.update(cx, |context, cx| {
 162        context
 163            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 164            .unwrap()
 165    });
 166    assert_eq!(
 167        messages(&context, cx),
 168        vec![
 169            (message_1.id, Role::User, 0..3),
 170            (message_5.id, Role::System, 3..4),
 171            (message_3.id, Role::User, 4..5)
 172        ]
 173    );
 174}
 175
 176#[gpui::test]
 177fn test_message_splitting(cx: &mut AppContext) {
 178    let settings_store = SettingsStore::test(cx);
 179    cx.set_global(settings_store);
 180    LanguageModelRegistry::test(cx);
 181    assistant_panel::init(cx);
 182    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 183
 184    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 185    let context =
 186        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 187    let buffer = context.read(cx).buffer.clone();
 188
 189    let message_1 = context.read(cx).message_anchors[0].clone();
 190    assert_eq!(
 191        messages(&context, cx),
 192        vec![(message_1.id, Role::User, 0..0)]
 193    );
 194
 195    buffer.update(cx, |buffer, cx| {
 196        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 197    });
 198
 199    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 200    let message_2 = message_2.unwrap();
 201
 202    // We recycle newlines in the middle of a split message
 203    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 204    assert_eq!(
 205        messages(&context, cx),
 206        vec![
 207            (message_1.id, Role::User, 0..4),
 208            (message_2.id, Role::User, 4..16),
 209        ]
 210    );
 211
 212    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 213    let message_3 = message_3.unwrap();
 214
 215    // We don't recycle newlines at the end of a split message
 216    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 217    assert_eq!(
 218        messages(&context, cx),
 219        vec![
 220            (message_1.id, Role::User, 0..4),
 221            (message_3.id, Role::User, 4..5),
 222            (message_2.id, Role::User, 5..17),
 223        ]
 224    );
 225
 226    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 227    let message_4 = message_4.unwrap();
 228    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 229    assert_eq!(
 230        messages(&context, cx),
 231        vec![
 232            (message_1.id, Role::User, 0..4),
 233            (message_3.id, Role::User, 4..5),
 234            (message_2.id, Role::User, 5..9),
 235            (message_4.id, Role::User, 9..17),
 236        ]
 237    );
 238
 239    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 240    let message_5 = message_5.unwrap();
 241    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 242    assert_eq!(
 243        messages(&context, cx),
 244        vec![
 245            (message_1.id, Role::User, 0..4),
 246            (message_3.id, Role::User, 4..5),
 247            (message_2.id, Role::User, 5..9),
 248            (message_4.id, Role::User, 9..10),
 249            (message_5.id, Role::User, 10..18),
 250        ]
 251    );
 252
 253    let (message_6, message_7) =
 254        context.update(cx, |context, cx| context.split_message(14..16, cx));
 255    let message_6 = message_6.unwrap();
 256    let message_7 = message_7.unwrap();
 257    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 258    assert_eq!(
 259        messages(&context, cx),
 260        vec![
 261            (message_1.id, Role::User, 0..4),
 262            (message_3.id, Role::User, 4..5),
 263            (message_2.id, Role::User, 5..9),
 264            (message_4.id, Role::User, 9..10),
 265            (message_5.id, Role::User, 10..14),
 266            (message_6.id, Role::User, 14..17),
 267            (message_7.id, Role::User, 17..19),
 268        ]
 269    );
 270}
 271
 272#[gpui::test]
 273fn test_messages_for_offsets(cx: &mut AppContext) {
 274    let settings_store = SettingsStore::test(cx);
 275    LanguageModelRegistry::test(cx);
 276    cx.set_global(settings_store);
 277    assistant_panel::init(cx);
 278    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 279    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 280    let context =
 281        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 282    let buffer = context.read(cx).buffer.clone();
 283
 284    let message_1 = context.read(cx).message_anchors[0].clone();
 285    assert_eq!(
 286        messages(&context, cx),
 287        vec![(message_1.id, Role::User, 0..0)]
 288    );
 289
 290    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 291    let message_2 = context
 292        .update(cx, |context, cx| {
 293            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 294        })
 295        .unwrap();
 296    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 297
 298    let message_3 = context
 299        .update(cx, |context, cx| {
 300            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 301        })
 302        .unwrap();
 303    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 304
 305    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 306    assert_eq!(
 307        messages(&context, cx),
 308        vec![
 309            (message_1.id, Role::User, 0..4),
 310            (message_2.id, Role::User, 4..8),
 311            (message_3.id, Role::User, 8..11)
 312        ]
 313    );
 314
 315    assert_eq!(
 316        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 317        [message_1.id, message_2.id, message_3.id]
 318    );
 319    assert_eq!(
 320        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 321        [message_1.id, message_3.id]
 322    );
 323
 324    let message_4 = context
 325        .update(cx, |context, cx| {
 326            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 327        })
 328        .unwrap();
 329    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 330    assert_eq!(
 331        messages(&context, cx),
 332        vec![
 333            (message_1.id, Role::User, 0..4),
 334            (message_2.id, Role::User, 4..8),
 335            (message_3.id, Role::User, 8..12),
 336            (message_4.id, Role::User, 12..12)
 337        ]
 338    );
 339    assert_eq!(
 340        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 341        [message_1.id, message_2.id, message_3.id, message_4.id]
 342    );
 343
 344    fn message_ids_for_offsets(
 345        context: &Model<Context>,
 346        offsets: &[usize],
 347        cx: &AppContext,
 348    ) -> Vec<MessageId> {
 349        context
 350            .read(cx)
 351            .messages_for_offsets(offsets.iter().copied(), cx)
 352            .into_iter()
 353            .map(|message| message.id)
 354            .collect()
 355    }
 356}
 357
 358#[gpui::test]
 359async fn test_slash_commands(cx: &mut TestAppContext) {
 360    let settings_store = cx.update(SettingsStore::test);
 361    cx.set_global(settings_store);
 362    cx.update(LanguageModelRegistry::test);
 363    cx.update(Project::init_settings);
 364    cx.update(assistant_panel::init);
 365    let fs = FakeFs::new(cx.background_executor.clone());
 366
 367    fs.insert_tree(
 368        "/test",
 369        json!({
 370            "src": {
 371                "lib.rs": "fn one() -> usize { 1 }",
 372                "main.rs": "
 373                    use crate::one;
 374                    fn main() { one(); }
 375                ".unindent(),
 376            }
 377        }),
 378    )
 379    .await;
 380
 381    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 382    slash_command_registry.register_command(file_command::FileSlashCommand, false);
 383
 384    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 385    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 386    let context =
 387        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 388
 389    #[derive(Default)]
 390    struct ContextRanges {
 391        parsed_commands: HashSet<Range<language::Anchor>>,
 392        command_outputs: HashMap<SlashCommandId, Range<language::Anchor>>,
 393        output_sections: HashSet<Range<language::Anchor>>,
 394    }
 395
 396    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 397    context.update(cx, |_, cx| {
 398        cx.subscribe(&context, {
 399            let context_ranges = context_ranges.clone();
 400            move |context, _, event, _| {
 401                let mut context_ranges = context_ranges.borrow_mut();
 402                match event {
 403                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 404                        let command = context.invoked_slash_command(command_id).unwrap();
 405                        context_ranges
 406                            .command_outputs
 407                            .insert(*command_id, command.range.clone());
 408                    }
 409                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 410                        for range in removed {
 411                            context_ranges.parsed_commands.remove(range);
 412                        }
 413                        for command in updated {
 414                            context_ranges
 415                                .parsed_commands
 416                                .insert(command.source_range.clone());
 417                        }
 418                    }
 419                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 420                        context_ranges.output_sections.insert(section.range.clone());
 421                    }
 422                    _ => {}
 423                }
 424            }
 425        })
 426        .detach();
 427    });
 428
 429    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 430
 431    // Insert a slash command
 432    buffer.update(cx, |buffer, cx| {
 433        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 434    });
 435    assert_text_and_context_ranges(
 436        &buffer,
 437        &context_ranges,
 438        &"
 439        «/file src/lib.rs»"
 440            .unindent(),
 441        cx,
 442    );
 443
 444    // Edit the argument of the slash command.
 445    buffer.update(cx, |buffer, cx| {
 446        let edit_offset = buffer.text().find("lib.rs").unwrap();
 447        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 448    });
 449    assert_text_and_context_ranges(
 450        &buffer,
 451        &context_ranges,
 452        &"
 453        «/file src/main.rs»"
 454            .unindent(),
 455        cx,
 456    );
 457
 458    // Edit the name of the slash command, using one that doesn't exist.
 459    buffer.update(cx, |buffer, cx| {
 460        let edit_offset = buffer.text().find("/file").unwrap();
 461        buffer.edit(
 462            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 463            None,
 464            cx,
 465        );
 466    });
 467    assert_text_and_context_ranges(
 468        &buffer,
 469        &context_ranges,
 470        &"
 471        /unknown src/main.rs"
 472            .unindent(),
 473        cx,
 474    );
 475
 476    // Undoing the insertion of an non-existent slash command resorts the previous one.
 477    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 478    assert_text_and_context_ranges(
 479        &buffer,
 480        &context_ranges,
 481        &"
 482        «/file src/main.rs»"
 483            .unindent(),
 484        cx,
 485    );
 486
 487    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 488    context.update(cx, |context, cx| {
 489        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 490        context.insert_command_output(
 491            command_source_range,
 492            "file",
 493            Task::ready(Ok(command_output_rx.boxed())),
 494            true,
 495            cx,
 496        );
 497    });
 498    assert_text_and_context_ranges(
 499        &buffer,
 500        &context_ranges,
 501        &"
 502        ⟦«/file src/main.rs»
 503        …⟧
 504        "
 505        .unindent(),
 506        cx,
 507    );
 508
 509    command_output_tx
 510        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 511            icon: IconName::Ai,
 512            label: "src/main.rs".into(),
 513            metadata: None,
 514        }))
 515        .unwrap();
 516    command_output_tx
 517        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 518        .unwrap();
 519    cx.run_until_parked();
 520    assert_text_and_context_ranges(
 521        &buffer,
 522        &context_ranges,
 523        &"
 524        ⟦«/file src/main.rs»
 525        src/main.rs…⟧
 526        "
 527        .unindent(),
 528        cx,
 529    );
 530
 531    command_output_tx
 532        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 533        .unwrap();
 534    cx.run_until_parked();
 535    assert_text_and_context_ranges(
 536        &buffer,
 537        &context_ranges,
 538        &"
 539        ⟦«/file src/main.rs»
 540        src/main.rs
 541        fn main() {}…⟧
 542        "
 543        .unindent(),
 544        cx,
 545    );
 546
 547    command_output_tx
 548        .unbounded_send(Ok(SlashCommandEvent::EndSection { metadata: None }))
 549        .unwrap();
 550    cx.run_until_parked();
 551    assert_text_and_context_ranges(
 552        &buffer,
 553        &context_ranges,
 554        &"
 555        ⟦«/file src/main.rs»
 556        ⟪src/main.rs
 557        fn main() {}⟫…⟧
 558        "
 559        .unindent(),
 560        cx,
 561    );
 562
 563    drop(command_output_tx);
 564    cx.run_until_parked();
 565    assert_text_and_context_ranges(
 566        &buffer,
 567        &context_ranges,
 568        &"
 569        ⟦⟪src/main.rs
 570        fn main() {}⟫⟧
 571        "
 572        .unindent(),
 573        cx,
 574    );
 575
 576    #[track_caller]
 577    fn assert_text_and_context_ranges(
 578        buffer: &Model<Buffer>,
 579        ranges: &RefCell<ContextRanges>,
 580        expected_marked_text: &str,
 581        cx: &mut TestAppContext,
 582    ) {
 583        let mut actual_marked_text = String::new();
 584        buffer.update(cx, |buffer, _| {
 585            struct Endpoint {
 586                offset: usize,
 587                marker: char,
 588            }
 589
 590            let ranges = ranges.borrow();
 591            let mut endpoints = Vec::new();
 592            for range in ranges.command_outputs.values() {
 593                endpoints.push(Endpoint {
 594                    offset: range.start.to_offset(buffer),
 595                    marker: '',
 596                });
 597            }
 598            for range in ranges.parsed_commands.iter() {
 599                endpoints.push(Endpoint {
 600                    offset: range.start.to_offset(buffer),
 601                    marker: '«',
 602                });
 603            }
 604            for range in ranges.output_sections.iter() {
 605                endpoints.push(Endpoint {
 606                    offset: range.start.to_offset(buffer),
 607                    marker: '',
 608                });
 609            }
 610
 611            for range in ranges.output_sections.iter() {
 612                endpoints.push(Endpoint {
 613                    offset: range.end.to_offset(buffer),
 614                    marker: '',
 615                });
 616            }
 617            for range in ranges.parsed_commands.iter() {
 618                endpoints.push(Endpoint {
 619                    offset: range.end.to_offset(buffer),
 620                    marker: '»',
 621                });
 622            }
 623            for range in ranges.command_outputs.values() {
 624                endpoints.push(Endpoint {
 625                    offset: range.end.to_offset(buffer),
 626                    marker: '',
 627                });
 628            }
 629
 630            endpoints.sort_by_key(|endpoint| endpoint.offset);
 631            let mut offset = 0;
 632            for endpoint in endpoints {
 633                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 634                actual_marked_text.push(endpoint.marker);
 635                offset = endpoint.offset;
 636            }
 637            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 638        });
 639
 640        assert_eq!(actual_marked_text, expected_marked_text);
 641    }
 642}
 643
 644#[gpui::test]
 645async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 646    cx.update(prompt_library::init);
 647    let mut settings_store = cx.update(SettingsStore::test);
 648    cx.update(|cx| {
 649        settings_store
 650            .set_user_settings(
 651                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 652                cx,
 653            )
 654            .unwrap()
 655    });
 656    cx.set_global(settings_store);
 657    cx.update(language::init);
 658    cx.update(Project::init_settings);
 659    let fs = FakeFs::new(cx.executor());
 660    let project = Project::test(fs, [Path::new("/root")], cx).await;
 661    cx.update(LanguageModelRegistry::test);
 662
 663    cx.update(assistant_panel::init);
 664    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 665
 666    // Create a new context
 667    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 668    let context = cx.new_model(|cx| {
 669        Context::local(
 670            registry.clone(),
 671            Some(project),
 672            None,
 673            prompt_builder.clone(),
 674            cx,
 675        )
 676    });
 677
 678    // Insert an assistant message to simulate a response.
 679    let assistant_message_id = context.update(cx, |context, cx| {
 680        let user_message_id = context.messages(cx).next().unwrap().id;
 681        context
 682            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 683            .unwrap()
 684            .id
 685    });
 686
 687    // No edit tags
 688    edit(
 689        &context,
 690        "
 691
 692        «one
 693        two
 694        »",
 695        cx,
 696    );
 697    expect_patches(
 698        &context,
 699        "
 700
 701        one
 702        two
 703        ",
 704        &[],
 705        cx,
 706    );
 707
 708    // Partial edit step tag is added
 709    edit(
 710        &context,
 711        "
 712
 713        one
 714        two
 715        «
 716        <patch»",
 717        cx,
 718    );
 719    expect_patches(
 720        &context,
 721        "
 722
 723        one
 724        two
 725
 726        <patch",
 727        &[],
 728        cx,
 729    );
 730
 731    // The rest of the step tag is added. The unclosed
 732    // step is treated as incomplete.
 733    edit(
 734        &context,
 735        "
 736
 737        one
 738        two
 739
 740        <patch«>
 741        <edit>»",
 742        cx,
 743    );
 744    expect_patches(
 745        &context,
 746        "
 747
 748        one
 749        two
 750
 751        «<patch>
 752        <edit>»",
 753        &[&[]],
 754        cx,
 755    );
 756
 757    // The full patch is added
 758    edit(
 759        &context,
 760        "
 761
 762        one
 763        two
 764
 765        <patch>
 766        <edit>«
 767        <description>add a `two` function</description>
 768        <path>src/lib.rs</path>
 769        <operation>insert_after</operation>
 770        <old_text>fn one</old_text>
 771        <new_text>
 772        fn two() {}
 773        </new_text>
 774        </edit>
 775        </patch>
 776
 777        also,»",
 778        cx,
 779    );
 780    expect_patches(
 781        &context,
 782        "
 783
 784        one
 785        two
 786
 787        «<patch>
 788        <edit>
 789        <description>add a `two` function</description>
 790        <path>src/lib.rs</path>
 791        <operation>insert_after</operation>
 792        <old_text>fn one</old_text>
 793        <new_text>
 794        fn two() {}
 795        </new_text>
 796        </edit>
 797        </patch>
 798        »
 799        also,",
 800        &[&[AssistantEdit {
 801            path: "src/lib.rs".into(),
 802            kind: AssistantEditKind::InsertAfter {
 803                old_text: "fn one".into(),
 804                new_text: "fn two() {}".into(),
 805                description: Some("add a `two` function".into()),
 806            },
 807        }]],
 808        cx,
 809    );
 810
 811    // The step is manually edited.
 812    edit(
 813        &context,
 814        "
 815
 816        one
 817        two
 818
 819        <patch>
 820        <edit>
 821        <description>add a `two` function</description>
 822        <path>src/lib.rs</path>
 823        <operation>insert_after</operation>
 824        <old_text>«fn zero»</old_text>
 825        <new_text>
 826        fn two() {}
 827        </new_text>
 828        </edit>
 829        </patch>
 830
 831        also,",
 832        cx,
 833    );
 834    expect_patches(
 835        &context,
 836        "
 837
 838        one
 839        two
 840
 841        «<patch>
 842        <edit>
 843        <description>add a `two` function</description>
 844        <path>src/lib.rs</path>
 845        <operation>insert_after</operation>
 846        <old_text>fn zero</old_text>
 847        <new_text>
 848        fn two() {}
 849        </new_text>
 850        </edit>
 851        </patch>
 852        »
 853        also,",
 854        &[&[AssistantEdit {
 855            path: "src/lib.rs".into(),
 856            kind: AssistantEditKind::InsertAfter {
 857                old_text: "fn zero".into(),
 858                new_text: "fn two() {}".into(),
 859                description: Some("add a `two` function".into()),
 860            },
 861        }]],
 862        cx,
 863    );
 864
 865    // When setting the message role to User, the steps are cleared.
 866    context.update(cx, |context, cx| {
 867        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 868        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 869    });
 870    expect_patches(
 871        &context,
 872        "
 873
 874        one
 875        two
 876
 877        <patch>
 878        <edit>
 879        <description>add a `two` function</description>
 880        <path>src/lib.rs</path>
 881        <operation>insert_after</operation>
 882        <old_text>fn zero</old_text>
 883        <new_text>
 884        fn two() {}
 885        </new_text>
 886        </edit>
 887        </patch>
 888
 889        also,",
 890        &[],
 891        cx,
 892    );
 893
 894    // When setting the message role back to Assistant, the steps are reparsed.
 895    context.update(cx, |context, cx| {
 896        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 897    });
 898    expect_patches(
 899        &context,
 900        "
 901
 902        one
 903        two
 904
 905        «<patch>
 906        <edit>
 907        <description>add a `two` function</description>
 908        <path>src/lib.rs</path>
 909        <operation>insert_after</operation>
 910        <old_text>fn zero</old_text>
 911        <new_text>
 912        fn two() {}
 913        </new_text>
 914        </edit>
 915        </patch>
 916        »
 917        also,",
 918        &[&[AssistantEdit {
 919            path: "src/lib.rs".into(),
 920            kind: AssistantEditKind::InsertAfter {
 921                old_text: "fn zero".into(),
 922                new_text: "fn two() {}".into(),
 923                description: Some("add a `two` function".into()),
 924            },
 925        }]],
 926        cx,
 927    );
 928
 929    // Ensure steps are re-parsed when deserializing.
 930    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 931    let deserialized_context = cx.new_model(|cx| {
 932        Context::deserialize(
 933            serialized_context,
 934            Default::default(),
 935            registry.clone(),
 936            prompt_builder.clone(),
 937            None,
 938            None,
 939            cx,
 940        )
 941    });
 942    expect_patches(
 943        &deserialized_context,
 944        "
 945
 946        one
 947        two
 948
 949        «<patch>
 950        <edit>
 951        <description>add a `two` function</description>
 952        <path>src/lib.rs</path>
 953        <operation>insert_after</operation>
 954        <old_text>fn zero</old_text>
 955        <new_text>
 956        fn two() {}
 957        </new_text>
 958        </edit>
 959        </patch>
 960        »
 961        also,",
 962        &[&[AssistantEdit {
 963            path: "src/lib.rs".into(),
 964            kind: AssistantEditKind::InsertAfter {
 965                old_text: "fn zero".into(),
 966                new_text: "fn two() {}".into(),
 967                description: Some("add a `two` function".into()),
 968            },
 969        }]],
 970        cx,
 971    );
 972
 973    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
 974        context.update(cx, |context, cx| {
 975            context.buffer.update(cx, |buffer, cx| {
 976                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
 977            });
 978        });
 979        cx.executor().run_until_parked();
 980    }
 981
 982    #[track_caller]
 983    fn expect_patches(
 984        context: &Model<Context>,
 985        expected_marked_text: &str,
 986        expected_suggestions: &[&[AssistantEdit]],
 987        cx: &mut TestAppContext,
 988    ) {
 989        let expected_marked_text = expected_marked_text.unindent();
 990        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
 991
 992        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
 993            context.buffer.read_with(cx, |buffer, _| {
 994                let ranges = context
 995                    .patches
 996                    .iter()
 997                    .map(|entry| entry.range.to_offset(buffer))
 998                    .collect::<Vec<_>>();
 999                (
1000                    buffer.text(),
1001                    ranges,
1002                    context
1003                        .patches
1004                        .iter()
1005                        .map(|step| step.edits.clone())
1006                        .collect::<Vec<_>>(),
1007                )
1008            })
1009        });
1010
1011        assert_eq!(buffer_text, expected_text);
1012
1013        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1014        assert_eq!(actual_marked_text, expected_marked_text);
1015
1016        assert_eq!(
1017            patches
1018                .iter()
1019                .map(|patch| {
1020                    patch
1021                        .iter()
1022                        .map(|edit| {
1023                            let edit = edit.as_ref().unwrap();
1024                            AssistantEdit {
1025                                path: edit.path.clone(),
1026                                kind: edit.kind.clone(),
1027                            }
1028                        })
1029                        .collect::<Vec<_>>()
1030                })
1031                .collect::<Vec<_>>(),
1032            expected_suggestions
1033        );
1034    }
1035}
1036
1037#[gpui::test]
1038async fn test_serialization(cx: &mut TestAppContext) {
1039    let settings_store = cx.update(SettingsStore::test);
1040    cx.set_global(settings_store);
1041    cx.update(LanguageModelRegistry::test);
1042    cx.update(assistant_panel::init);
1043    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1044    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1045    let context =
1046        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
1047    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1048    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1049    let message_1 = context.update(cx, |context, cx| {
1050        context
1051            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1052            .unwrap()
1053    });
1054    let message_2 = context.update(cx, |context, cx| {
1055        context
1056            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1057            .unwrap()
1058    });
1059    buffer.update(cx, |buffer, cx| {
1060        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1061        buffer.finalize_last_transaction();
1062    });
1063    let _message_3 = context.update(cx, |context, cx| {
1064        context
1065            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1066            .unwrap()
1067    });
1068    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1069    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1070    assert_eq!(
1071        cx.read(|cx| messages(&context, cx)),
1072        [
1073            (message_0, Role::User, 0..2),
1074            (message_1.id, Role::Assistant, 2..6),
1075            (message_2.id, Role::System, 6..6),
1076        ]
1077    );
1078
1079    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1080    let deserialized_context = cx.new_model(|cx| {
1081        Context::deserialize(
1082            serialized_context,
1083            Default::default(),
1084            registry.clone(),
1085            prompt_builder.clone(),
1086            None,
1087            None,
1088            cx,
1089        )
1090    });
1091    let deserialized_buffer =
1092        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1093    assert_eq!(
1094        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1095        "a\nb\nc\n"
1096    );
1097    assert_eq!(
1098        cx.read(|cx| messages(&deserialized_context, cx)),
1099        [
1100            (message_0, Role::User, 0..2),
1101            (message_1.id, Role::Assistant, 2..6),
1102            (message_2.id, Role::System, 6..6),
1103        ]
1104    );
1105}
1106
// Randomized convergence test: `num_peers` replicas of the same context apply random local
// mutations (buffer edits, message splits/inserts, slash command output, metadata updates),
// exchange operations over a simulated `Network` that randomly disconnects and reconnects
// peers, and must finish with identical buffer text, message anchors, message metadata, and
// slash command output sections.
//
// Tunable via env vars: MIN_PEERS (default 2), MAX_PEERS (default 5), OPERATIONS (default 50,
// the number of local mutations performed across all replicas).
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    // Register fake commands so the random slash-command mutation has names to choose from.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    // All replicas share one context id; each replica's id is its index in `contexts`.
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });

        // Broadcast every operation this replica emits over the simulated network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Keep going until the mutation budget is spent, the network has delivered everything,
    // and no peer remains disconnected. Once the budget is exhausted, every roll falls through
    // to the networking arm (`_`) because the mutation arms are guarded by `mutation_count > 0`.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        match rng.gen_range(0..100) {
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            60..=74 if mutation_count > 0 => {
                // Insert a random `/cmd-N` line into the buffer, then feed a synthetic stream of
                // slash command events (0-3 sections plus optional trailing text) for it.
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Byte range of the inserted command text, excluding the surrounding newlines.
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(10)
                        .collect::<String>();

                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
                        role: Role::User,
                        merge_same_roles: true,
                    })];

                    let num_sections = rng.gen_range(0..=3);
                    let mut section_start = 0;
                    for _ in 0..num_sections {
                        let mut section_end = rng.gen_range(section_start..=output_text.len());
                        // Round up to a char boundary so slicing `output_text` cannot panic on
                        // multi-byte characters.
                        while !output_text.is_char_boundary(section_end) {
                            section_end += 1;
                        }
                        events.push(Ok(SlashCommandEvent::StartSection {
                            icon: IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        }));
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..section_end].to_string(),
                            run_commands_in_text: false,
                        })));
                        events.push(Ok(SlashCommandEvent::EndSection { metadata: None }));
                        section_start = section_end;
                    }

                    // Any text not covered by a section is emitted as unsectioned content.
                    if section_start < output_text.len() {
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..].to_string(),
                            run_commands_in_text: false,
                        })));
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?} events",
                        context_index,
                        command_range,
                        events.len()
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        "/command",
                        Task::ready(Ok(stream::iter(events).boxed())),
                        true,
                        cx,
                    );
                });
                // Let the command output stream be consumed before the next step.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    network.lock().reconnect_peer(replica_id, 0);

                    // Resync with replica 0: each side serializes the operations that the
                    // other side's version does not yet include.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Replica 0 is the reconnection source above, so it is never disconnected.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Every replica must have converged to replica 0's state with nothing left pending.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1388
1389#[gpui::test]
1390fn test_mark_cache_anchors(cx: &mut AppContext) {
1391    let settings_store = SettingsStore::test(cx);
1392    LanguageModelRegistry::test(cx);
1393    cx.set_global(settings_store);
1394    assistant_panel::init(cx);
1395    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1396    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1397    let context =
1398        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1399    let buffer = context.read(cx).buffer.clone();
1400
1401    // Create a test cache configuration
1402    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1403        max_cache_anchors: 3,
1404        should_speculate: true,
1405        min_total_token: 10,
1406    });
1407
1408    let message_1 = context.read(cx).message_anchors[0].clone();
1409
1410    context.update(cx, |context, cx| {
1411        context.mark_cache_anchors(cache_configuration, false, cx)
1412    });
1413
1414    assert_eq!(
1415        messages_cache(&context, cx)
1416            .iter()
1417            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1418            .count(),
1419        0,
1420        "Empty messages should not have any cache anchors."
1421    );
1422
1423    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1424    let message_2 = context
1425        .update(cx, |context, cx| {
1426            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1427        })
1428        .unwrap();
1429
1430    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1431    let message_3 = context
1432        .update(cx, |context, cx| {
1433            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1434        })
1435        .unwrap();
1436    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1437
1438    context.update(cx, |context, cx| {
1439        context.mark_cache_anchors(cache_configuration, false, cx)
1440    });
1441    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1442    assert_eq!(
1443        messages_cache(&context, cx)
1444            .iter()
1445            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1446            .count(),
1447        0,
1448        "Messages should not be marked for cache before going over the token minimum."
1449    );
1450    context.update(cx, |context, _| {
1451        context.token_count = Some(20);
1452    });
1453
1454    context.update(cx, |context, cx| {
1455        context.mark_cache_anchors(cache_configuration, true, cx)
1456    });
1457    assert_eq!(
1458        messages_cache(&context, cx)
1459            .iter()
1460            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1461            .collect::<Vec<bool>>(),
1462        vec![true, true, false],
1463        "Last message should not be an anchor on speculative request."
1464    );
1465
1466    context
1467        .update(cx, |context, cx| {
1468            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1469        })
1470        .unwrap();
1471
1472    context.update(cx, |context, cx| {
1473        context.mark_cache_anchors(cache_configuration, false, cx)
1474    });
1475    assert_eq!(
1476        messages_cache(&context, cx)
1477            .iter()
1478            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1479            .collect::<Vec<bool>>(),
1480        vec![false, true, true, false],
1481        "Most recent message should also be cached if not a speculative request."
1482    );
1483    context.update(cx, |context, cx| {
1484        context.update_cache_status_for_completion(cx)
1485    });
1486    assert_eq!(
1487        messages_cache(&context, cx)
1488            .iter()
1489            .map(|(_, cache)| cache
1490                .as_ref()
1491                .map_or(None, |cache| Some(cache.status.clone())))
1492            .collect::<Vec<Option<CacheStatus>>>(),
1493        vec![
1494            Some(CacheStatus::Cached),
1495            Some(CacheStatus::Cached),
1496            Some(CacheStatus::Cached),
1497            None
1498        ],
1499        "All user messages prior to anchor should be marked as cached."
1500    );
1501
1502    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1503    context.update(cx, |context, cx| {
1504        context.mark_cache_anchors(cache_configuration, false, cx)
1505    });
1506    assert_eq!(
1507        messages_cache(&context, cx)
1508            .iter()
1509            .map(|(_, cache)| cache
1510                .as_ref()
1511                .map_or(None, |cache| Some(cache.status.clone())))
1512            .collect::<Vec<Option<CacheStatus>>>(),
1513        vec![
1514            Some(CacheStatus::Cached),
1515            Some(CacheStatus::Cached),
1516            Some(CacheStatus::Pending),
1517            None
1518        ],
1519        "Modifying a message should invalidate it's cache but leave previous messages."
1520    );
1521    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1522    context.update(cx, |context, cx| {
1523        context.mark_cache_anchors(cache_configuration, false, cx)
1524    });
1525    assert_eq!(
1526        messages_cache(&context, cx)
1527            .iter()
1528            .map(|(_, cache)| cache
1529                .as_ref()
1530                .map_or(None, |cache| Some(cache.status.clone())))
1531            .collect::<Vec<Option<CacheStatus>>>(),
1532        vec![
1533            Some(CacheStatus::Pending),
1534            Some(CacheStatus::Pending),
1535            Some(CacheStatus::Pending),
1536            None
1537        ],
1538        "Modifying a message should invalidate all future messages."
1539    );
1540}
1541
1542fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1543    context
1544        .read(cx)
1545        .messages(cx)
1546        .map(|message| (message.id, message.role, message.offset_range))
1547        .collect()
1548}
1549
1550fn messages_cache(
1551    context: &Model<Context>,
1552    cx: &AppContext,
1553) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1554    context
1555        .read(cx)
1556        .messages(cx)
1557        .map(|message| (message.id, message.cache.clone()))
1558        .collect()
1559}
1560
1561#[derive(Clone)]
1562struct FakeSlashCommand(String);
1563
1564impl SlashCommand for FakeSlashCommand {
1565    fn name(&self) -> String {
1566        self.0.clone()
1567    }
1568
1569    fn description(&self) -> String {
1570        format!("Fake slash command: {}", self.0)
1571    }
1572
1573    fn menu_text(&self) -> String {
1574        format!("Run fake command: {}", self.0)
1575    }
1576
1577    fn complete_argument(
1578        self: Arc<Self>,
1579        _arguments: &[String],
1580        _cancel: Arc<AtomicBool>,
1581        _workspace: Option<WeakView<Workspace>>,
1582        _cx: &mut WindowContext,
1583    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1584        Task::ready(Ok(vec![]))
1585    }
1586
1587    fn requires_argument(&self) -> bool {
1588        false
1589    }
1590
1591    fn run(
1592        self: Arc<Self>,
1593        _arguments: &[String],
1594        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1595        _context_buffer: BufferSnapshot,
1596        _workspace: WeakView<Workspace>,
1597        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1598        _cx: &mut WindowContext,
1599    ) -> Task<SlashCommandResult> {
1600        Task::ready(Ok(SlashCommandOutput {
1601            text: format!("Executed fake command: {}", self.0),
1602            sections: vec![],
1603            run_commands_in_text: false,
1604        }
1605        .to_event_stream()))
1606    }
1607}