context_tests.rs

   1use crate::{
   2    AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
   3    ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  20use parking_lot::Mutex;
  21use pretty_assertions::assert_eq;
  22use project::Project;
  23use prompt_store::PromptBuilder;
  24use rand::prelude::*;
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{
  28    cell::RefCell,
  29    env,
  30    ops::Range,
  31    path::Path,
  32    rc::Rc,
  33    sync::{Arc, atomic::AtomicBool},
  34};
  35use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
  36use ui::{IconName, Window};
  37use unindent::Unindent;
  38use util::{
  39    RandomCharIter,
  40    test::{generate_marked_text, marked_text_ranges},
  41};
  42use workspace::Workspace;
  43
  44#[gpui::test]
  45fn test_inserting_and_removing_messages(cx: &mut App) {
  46    let settings_store = SettingsStore::test(cx);
  47    LanguageModelRegistry::test(cx);
  48    cx.set_global(settings_store);
  49    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  50    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  51    let context = cx.new(|cx| {
  52        AssistantContext::local(
  53            registry,
  54            None,
  55            None,
  56            prompt_builder.clone(),
  57            Arc::new(SlashCommandWorkingSet::default()),
  58            cx,
  59        )
  60    });
  61    let buffer = context.read(cx).buffer.clone();
  62
  63    let message_1 = context.read(cx).message_anchors[0].clone();
  64    assert_eq!(
  65        messages(&context, cx),
  66        vec![(message_1.id, Role::User, 0..0)]
  67    );
  68
  69    let message_2 = context.update(cx, |context, cx| {
  70        context
  71            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  72            .unwrap()
  73    });
  74    assert_eq!(
  75        messages(&context, cx),
  76        vec![
  77            (message_1.id, Role::User, 0..1),
  78            (message_2.id, Role::Assistant, 1..1)
  79        ]
  80    );
  81
  82    buffer.update(cx, |buffer, cx| {
  83        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  84    });
  85    assert_eq!(
  86        messages(&context, cx),
  87        vec![
  88            (message_1.id, Role::User, 0..2),
  89            (message_2.id, Role::Assistant, 2..3)
  90        ]
  91    );
  92
  93    let message_3 = context.update(cx, |context, cx| {
  94        context
  95            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  96            .unwrap()
  97    });
  98    assert_eq!(
  99        messages(&context, cx),
 100        vec![
 101            (message_1.id, Role::User, 0..2),
 102            (message_2.id, Role::Assistant, 2..4),
 103            (message_3.id, Role::User, 4..4)
 104        ]
 105    );
 106
 107    let message_4 = context.update(cx, |context, cx| {
 108        context
 109            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 110            .unwrap()
 111    });
 112    assert_eq!(
 113        messages(&context, cx),
 114        vec![
 115            (message_1.id, Role::User, 0..2),
 116            (message_2.id, Role::Assistant, 2..4),
 117            (message_4.id, Role::User, 4..5),
 118            (message_3.id, Role::User, 5..5),
 119        ]
 120    );
 121
 122    buffer.update(cx, |buffer, cx| {
 123        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 124    });
 125    assert_eq!(
 126        messages(&context, cx),
 127        vec![
 128            (message_1.id, Role::User, 0..2),
 129            (message_2.id, Role::Assistant, 2..4),
 130            (message_4.id, Role::User, 4..6),
 131            (message_3.id, Role::User, 6..7),
 132        ]
 133    );
 134
 135    // Deleting across message boundaries merges the messages.
 136    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 137    assert_eq!(
 138        messages(&context, cx),
 139        vec![
 140            (message_1.id, Role::User, 0..3),
 141            (message_3.id, Role::User, 3..4),
 142        ]
 143    );
 144
 145    // Undoing the deletion should also undo the merge.
 146    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 147    assert_eq!(
 148        messages(&context, cx),
 149        vec![
 150            (message_1.id, Role::User, 0..2),
 151            (message_2.id, Role::Assistant, 2..4),
 152            (message_4.id, Role::User, 4..6),
 153            (message_3.id, Role::User, 6..7),
 154        ]
 155    );
 156
 157    // Redoing the deletion should also redo the merge.
 158    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 159    assert_eq!(
 160        messages(&context, cx),
 161        vec![
 162            (message_1.id, Role::User, 0..3),
 163            (message_3.id, Role::User, 3..4),
 164        ]
 165    );
 166
 167    // Ensure we can still insert after a merged message.
 168    let message_5 = context.update(cx, |context, cx| {
 169        context
 170            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 171            .unwrap()
 172    });
 173    assert_eq!(
 174        messages(&context, cx),
 175        vec![
 176            (message_1.id, Role::User, 0..3),
 177            (message_5.id, Role::System, 3..4),
 178            (message_3.id, Role::User, 4..5)
 179        ]
 180    );
 181}
 182
 183#[gpui::test]
 184fn test_message_splitting(cx: &mut App) {
 185    let settings_store = SettingsStore::test(cx);
 186    cx.set_global(settings_store);
 187    LanguageModelRegistry::test(cx);
 188    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 189
 190    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 191    let context = cx.new(|cx| {
 192        AssistantContext::local(
 193            registry.clone(),
 194            None,
 195            None,
 196            prompt_builder.clone(),
 197            Arc::new(SlashCommandWorkingSet::default()),
 198            cx,
 199        )
 200    });
 201    let buffer = context.read(cx).buffer.clone();
 202
 203    let message_1 = context.read(cx).message_anchors[0].clone();
 204    assert_eq!(
 205        messages(&context, cx),
 206        vec![(message_1.id, Role::User, 0..0)]
 207    );
 208
 209    buffer.update(cx, |buffer, cx| {
 210        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 211    });
 212
 213    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 214    let message_2 = message_2.unwrap();
 215
 216    // We recycle newlines in the middle of a split message
 217    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 218    assert_eq!(
 219        messages(&context, cx),
 220        vec![
 221            (message_1.id, Role::User, 0..4),
 222            (message_2.id, Role::User, 4..16),
 223        ]
 224    );
 225
 226    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 227    let message_3 = message_3.unwrap();
 228
 229    // We don't recycle newlines at the end of a split message
 230    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 231    assert_eq!(
 232        messages(&context, cx),
 233        vec![
 234            (message_1.id, Role::User, 0..4),
 235            (message_3.id, Role::User, 4..5),
 236            (message_2.id, Role::User, 5..17),
 237        ]
 238    );
 239
 240    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 241    let message_4 = message_4.unwrap();
 242    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 243    assert_eq!(
 244        messages(&context, cx),
 245        vec![
 246            (message_1.id, Role::User, 0..4),
 247            (message_3.id, Role::User, 4..5),
 248            (message_2.id, Role::User, 5..9),
 249            (message_4.id, Role::User, 9..17),
 250        ]
 251    );
 252
 253    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 254    let message_5 = message_5.unwrap();
 255    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 256    assert_eq!(
 257        messages(&context, cx),
 258        vec![
 259            (message_1.id, Role::User, 0..4),
 260            (message_3.id, Role::User, 4..5),
 261            (message_2.id, Role::User, 5..9),
 262            (message_4.id, Role::User, 9..10),
 263            (message_5.id, Role::User, 10..18),
 264        ]
 265    );
 266
 267    let (message_6, message_7) =
 268        context.update(cx, |context, cx| context.split_message(14..16, cx));
 269    let message_6 = message_6.unwrap();
 270    let message_7 = message_7.unwrap();
 271    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 272    assert_eq!(
 273        messages(&context, cx),
 274        vec![
 275            (message_1.id, Role::User, 0..4),
 276            (message_3.id, Role::User, 4..5),
 277            (message_2.id, Role::User, 5..9),
 278            (message_4.id, Role::User, 9..10),
 279            (message_5.id, Role::User, 10..14),
 280            (message_6.id, Role::User, 14..17),
 281            (message_7.id, Role::User, 17..19),
 282        ]
 283    );
 284}
 285
 286#[gpui::test]
 287fn test_messages_for_offsets(cx: &mut App) {
 288    let settings_store = SettingsStore::test(cx);
 289    LanguageModelRegistry::test(cx);
 290    cx.set_global(settings_store);
 291    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 292    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 293    let context = cx.new(|cx| {
 294        AssistantContext::local(
 295            registry,
 296            None,
 297            None,
 298            prompt_builder.clone(),
 299            Arc::new(SlashCommandWorkingSet::default()),
 300            cx,
 301        )
 302    });
 303    let buffer = context.read(cx).buffer.clone();
 304
 305    let message_1 = context.read(cx).message_anchors[0].clone();
 306    assert_eq!(
 307        messages(&context, cx),
 308        vec![(message_1.id, Role::User, 0..0)]
 309    );
 310
 311    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 312    let message_2 = context
 313        .update(cx, |context, cx| {
 314            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 315        })
 316        .unwrap();
 317    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 318
 319    let message_3 = context
 320        .update(cx, |context, cx| {
 321            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 322        })
 323        .unwrap();
 324    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 325
 326    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 327    assert_eq!(
 328        messages(&context, cx),
 329        vec![
 330            (message_1.id, Role::User, 0..4),
 331            (message_2.id, Role::User, 4..8),
 332            (message_3.id, Role::User, 8..11)
 333        ]
 334    );
 335
 336    assert_eq!(
 337        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 338        [message_1.id, message_2.id, message_3.id]
 339    );
 340    assert_eq!(
 341        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 342        [message_1.id, message_3.id]
 343    );
 344
 345    let message_4 = context
 346        .update(cx, |context, cx| {
 347            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 348        })
 349        .unwrap();
 350    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 351    assert_eq!(
 352        messages(&context, cx),
 353        vec![
 354            (message_1.id, Role::User, 0..4),
 355            (message_2.id, Role::User, 4..8),
 356            (message_3.id, Role::User, 8..12),
 357            (message_4.id, Role::User, 12..12)
 358        ]
 359    );
 360    assert_eq!(
 361        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 362        [message_1.id, message_2.id, message_3.id, message_4.id]
 363    );
 364
 365    fn message_ids_for_offsets(
 366        context: &Entity<AssistantContext>,
 367        offsets: &[usize],
 368        cx: &App,
 369    ) -> Vec<MessageId> {
 370        context
 371            .read(cx)
 372            .messages_for_offsets(offsets.iter().copied(), cx)
 373            .into_iter()
 374            .map(|message| message.id)
 375            .collect()
 376    }
 377}
 378
 379#[gpui::test]
 380async fn test_slash_commands(cx: &mut TestAppContext) {
 381    let settings_store = cx.update(SettingsStore::test);
 382    cx.set_global(settings_store);
 383    cx.update(LanguageModelRegistry::test);
 384    cx.update(Project::init_settings);
 385    let fs = FakeFs::new(cx.background_executor.clone());
 386
 387    fs.insert_tree(
 388        "/test",
 389        json!({
 390            "src": {
 391                "lib.rs": "fn one() -> usize { 1 }",
 392                "main.rs": "
 393                    use crate::one;
 394                    fn main() { one(); }
 395                ".unindent(),
 396            }
 397        }),
 398    )
 399    .await;
 400
 401    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 402    slash_command_registry.register_command(FileSlashCommand, false);
 403
 404    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 405    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 406    let context = cx.new(|cx| {
 407        AssistantContext::local(
 408            registry.clone(),
 409            None,
 410            None,
 411            prompt_builder.clone(),
 412            Arc::new(SlashCommandWorkingSet::default()),
 413            cx,
 414        )
 415    });
 416
 417    #[derive(Default)]
 418    struct ContextRanges {
 419        parsed_commands: HashSet<Range<language::Anchor>>,
 420        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 421        output_sections: HashSet<Range<language::Anchor>>,
 422    }
 423
 424    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 425    context.update(cx, |_, cx| {
 426        cx.subscribe(&context, {
 427            let context_ranges = context_ranges.clone();
 428            move |context, _, event, _| {
 429                let mut context_ranges = context_ranges.borrow_mut();
 430                match event {
 431                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 432                        let command = context.invoked_slash_command(command_id).unwrap();
 433                        context_ranges
 434                            .command_outputs
 435                            .insert(*command_id, command.range.clone());
 436                    }
 437                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 438                        for range in removed {
 439                            context_ranges.parsed_commands.remove(range);
 440                        }
 441                        for command in updated {
 442                            context_ranges
 443                                .parsed_commands
 444                                .insert(command.source_range.clone());
 445                        }
 446                    }
 447                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 448                        context_ranges.output_sections.insert(section.range.clone());
 449                    }
 450                    _ => {}
 451                }
 452            }
 453        })
 454        .detach();
 455    });
 456
 457    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 458
 459    // Insert a slash command
 460    buffer.update(cx, |buffer, cx| {
 461        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 462    });
 463    assert_text_and_context_ranges(
 464        &buffer,
 465        &context_ranges,
 466        &"
 467        «/file src/lib.rs»"
 468            .unindent(),
 469        cx,
 470    );
 471
 472    // Edit the argument of the slash command.
 473    buffer.update(cx, |buffer, cx| {
 474        let edit_offset = buffer.text().find("lib.rs").unwrap();
 475        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 476    });
 477    assert_text_and_context_ranges(
 478        &buffer,
 479        &context_ranges,
 480        &"
 481        «/file src/main.rs»"
 482            .unindent(),
 483        cx,
 484    );
 485
 486    // Edit the name of the slash command, using one that doesn't exist.
 487    buffer.update(cx, |buffer, cx| {
 488        let edit_offset = buffer.text().find("/file").unwrap();
 489        buffer.edit(
 490            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 491            None,
 492            cx,
 493        );
 494    });
 495    assert_text_and_context_ranges(
 496        &buffer,
 497        &context_ranges,
 498        &"
 499        /unknown src/main.rs"
 500            .unindent(),
 501        cx,
 502    );
 503
 504    // Undoing the insertion of an non-existent slash command resorts the previous one.
 505    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 506    assert_text_and_context_ranges(
 507        &buffer,
 508        &context_ranges,
 509        &"
 510        «/file src/main.rs»"
 511            .unindent(),
 512        cx,
 513    );
 514
 515    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 516    context.update(cx, |context, cx| {
 517        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 518        context.insert_command_output(
 519            command_source_range,
 520            "file",
 521            Task::ready(Ok(command_output_rx.boxed())),
 522            true,
 523            cx,
 524        );
 525    });
 526    assert_text_and_context_ranges(
 527        &buffer,
 528        &context_ranges,
 529        &"
 530        ⟦«/file src/main.rs»
 531        …⟧
 532        "
 533        .unindent(),
 534        cx,
 535    );
 536
 537    command_output_tx
 538        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 539            icon: IconName::Ai,
 540            label: "src/main.rs".into(),
 541            metadata: None,
 542        }))
 543        .unwrap();
 544    command_output_tx
 545        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 546        .unwrap();
 547    cx.run_until_parked();
 548    assert_text_and_context_ranges(
 549        &buffer,
 550        &context_ranges,
 551        &"
 552        ⟦«/file src/main.rs»
 553        src/main.rs…⟧
 554        "
 555        .unindent(),
 556        cx,
 557    );
 558
 559    command_output_tx
 560        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 561        .unwrap();
 562    cx.run_until_parked();
 563    assert_text_and_context_ranges(
 564        &buffer,
 565        &context_ranges,
 566        &"
 567        ⟦«/file src/main.rs»
 568        src/main.rs
 569        fn main() {}…⟧
 570        "
 571        .unindent(),
 572        cx,
 573    );
 574
 575    command_output_tx
 576        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 577        .unwrap();
 578    cx.run_until_parked();
 579    assert_text_and_context_ranges(
 580        &buffer,
 581        &context_ranges,
 582        &"
 583        ⟦«/file src/main.rs»
 584        ⟪src/main.rs
 585        fn main() {}⟫…⟧
 586        "
 587        .unindent(),
 588        cx,
 589    );
 590
 591    drop(command_output_tx);
 592    cx.run_until_parked();
 593    assert_text_and_context_ranges(
 594        &buffer,
 595        &context_ranges,
 596        &"
 597        ⟦⟪src/main.rs
 598        fn main() {}⟫⟧
 599        "
 600        .unindent(),
 601        cx,
 602    );
 603
 604    #[track_caller]
 605    fn assert_text_and_context_ranges(
 606        buffer: &Entity<Buffer>,
 607        ranges: &RefCell<ContextRanges>,
 608        expected_marked_text: &str,
 609        cx: &mut TestAppContext,
 610    ) {
 611        let mut actual_marked_text = String::new();
 612        buffer.update(cx, |buffer, _| {
 613            struct Endpoint {
 614                offset: usize,
 615                marker: char,
 616            }
 617
 618            let ranges = ranges.borrow();
 619            let mut endpoints = Vec::new();
 620            for range in ranges.command_outputs.values() {
 621                endpoints.push(Endpoint {
 622                    offset: range.start.to_offset(buffer),
 623                    marker: '',
 624                });
 625            }
 626            for range in ranges.parsed_commands.iter() {
 627                endpoints.push(Endpoint {
 628                    offset: range.start.to_offset(buffer),
 629                    marker: '«',
 630                });
 631            }
 632            for range in ranges.output_sections.iter() {
 633                endpoints.push(Endpoint {
 634                    offset: range.start.to_offset(buffer),
 635                    marker: '',
 636                });
 637            }
 638
 639            for range in ranges.output_sections.iter() {
 640                endpoints.push(Endpoint {
 641                    offset: range.end.to_offset(buffer),
 642                    marker: '',
 643                });
 644            }
 645            for range in ranges.parsed_commands.iter() {
 646                endpoints.push(Endpoint {
 647                    offset: range.end.to_offset(buffer),
 648                    marker: '»',
 649                });
 650            }
 651            for range in ranges.command_outputs.values() {
 652                endpoints.push(Endpoint {
 653                    offset: range.end.to_offset(buffer),
 654                    marker: '',
 655                });
 656            }
 657
 658            endpoints.sort_by_key(|endpoint| endpoint.offset);
 659            let mut offset = 0;
 660            for endpoint in endpoints {
 661                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 662                actual_marked_text.push(endpoint.marker);
 663                offset = endpoint.offset;
 664            }
 665            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 666        });
 667
 668        assert_eq!(actual_marked_text, expected_marked_text);
 669    }
 670}
 671
 672#[gpui::test]
 673async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 674    cx.update(prompt_store::init);
 675    let mut settings_store = cx.update(SettingsStore::test);
 676    cx.update(|cx| {
 677        settings_store
 678            .set_user_settings(
 679                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 680                cx,
 681            )
 682            .unwrap()
 683    });
 684    cx.set_global(settings_store);
 685    cx.update(language::init);
 686    cx.update(Project::init_settings);
 687    let fs = FakeFs::new(cx.executor());
 688    let project = Project::test(fs, [Path::new("/root")], cx).await;
 689    cx.update(LanguageModelRegistry::test);
 690
 691    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 692
 693    // Create a new context
 694    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 695    let context = cx.new(|cx| {
 696        AssistantContext::local(
 697            registry.clone(),
 698            Some(project),
 699            None,
 700            prompt_builder.clone(),
 701            Arc::new(SlashCommandWorkingSet::default()),
 702            cx,
 703        )
 704    });
 705
 706    // Insert an assistant message to simulate a response.
 707    let assistant_message_id = context.update(cx, |context, cx| {
 708        let user_message_id = context.messages(cx).next().unwrap().id;
 709        context
 710            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 711            .unwrap()
 712            .id
 713    });
 714
 715    // No edit tags
 716    edit(
 717        &context,
 718        "
 719
 720        «one
 721        two
 722        »",
 723        cx,
 724    );
 725    expect_patches(
 726        &context,
 727        "
 728
 729        one
 730        two
 731        ",
 732        &[],
 733        cx,
 734    );
 735
 736    // Partial edit step tag is added
 737    edit(
 738        &context,
 739        "
 740
 741        one
 742        two
 743        «
 744        <patch»",
 745        cx,
 746    );
 747    expect_patches(
 748        &context,
 749        "
 750
 751        one
 752        two
 753
 754        <patch",
 755        &[],
 756        cx,
 757    );
 758
 759    // The rest of the step tag is added. The unclosed
 760    // step is treated as incomplete.
 761    edit(
 762        &context,
 763        "
 764
 765        one
 766        two
 767
 768        <patch«>
 769        <edit>»",
 770        cx,
 771    );
 772    expect_patches(
 773        &context,
 774        "
 775
 776        one
 777        two
 778
 779        «<patch>
 780        <edit>»",
 781        &[&[]],
 782        cx,
 783    );
 784
 785    // The full patch is added
 786    edit(
 787        &context,
 788        "
 789
 790        one
 791        two
 792
 793        <patch>
 794        <edit>«
 795        <description>add a `two` function</description>
 796        <path>src/lib.rs</path>
 797        <operation>insert_after</operation>
 798        <old_text>fn one</old_text>
 799        <new_text>
 800        fn two() {}
 801        </new_text>
 802        </edit>
 803        </patch>
 804
 805        also,»",
 806        cx,
 807    );
 808    expect_patches(
 809        &context,
 810        "
 811
 812        one
 813        two
 814
 815        «<patch>
 816        <edit>
 817        <description>add a `two` function</description>
 818        <path>src/lib.rs</path>
 819        <operation>insert_after</operation>
 820        <old_text>fn one</old_text>
 821        <new_text>
 822        fn two() {}
 823        </new_text>
 824        </edit>
 825        </patch>
 826        »
 827        also,",
 828        &[&[AssistantEdit {
 829            path: "src/lib.rs".into(),
 830            kind: AssistantEditKind::InsertAfter {
 831                old_text: "fn one".into(),
 832                new_text: "fn two() {}".into(),
 833                description: Some("add a `two` function".into()),
 834            },
 835        }]],
 836        cx,
 837    );
 838
 839    // The step is manually edited.
 840    edit(
 841        &context,
 842        "
 843
 844        one
 845        two
 846
 847        <patch>
 848        <edit>
 849        <description>add a `two` function</description>
 850        <path>src/lib.rs</path>
 851        <operation>insert_after</operation>
 852        <old_text>«fn zero»</old_text>
 853        <new_text>
 854        fn two() {}
 855        </new_text>
 856        </edit>
 857        </patch>
 858
 859        also,",
 860        cx,
 861    );
 862    expect_patches(
 863        &context,
 864        "
 865
 866        one
 867        two
 868
 869        «<patch>
 870        <edit>
 871        <description>add a `two` function</description>
 872        <path>src/lib.rs</path>
 873        <operation>insert_after</operation>
 874        <old_text>fn zero</old_text>
 875        <new_text>
 876        fn two() {}
 877        </new_text>
 878        </edit>
 879        </patch>
 880        »
 881        also,",
 882        &[&[AssistantEdit {
 883            path: "src/lib.rs".into(),
 884            kind: AssistantEditKind::InsertAfter {
 885                old_text: "fn zero".into(),
 886                new_text: "fn two() {}".into(),
 887                description: Some("add a `two` function".into()),
 888            },
 889        }]],
 890        cx,
 891    );
 892
 893    // When setting the message role to User, the steps are cleared.
 894    context.update(cx, |context, cx| {
 895        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 896        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 897    });
 898    expect_patches(
 899        &context,
 900        "
 901
 902        one
 903        two
 904
 905        <patch>
 906        <edit>
 907        <description>add a `two` function</description>
 908        <path>src/lib.rs</path>
 909        <operation>insert_after</operation>
 910        <old_text>fn zero</old_text>
 911        <new_text>
 912        fn two() {}
 913        </new_text>
 914        </edit>
 915        </patch>
 916
 917        also,",
 918        &[],
 919        cx,
 920    );
 921
 922    // When setting the message role back to Assistant, the steps are reparsed.
 923    context.update(cx, |context, cx| {
 924        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 925    });
 926    expect_patches(
 927        &context,
 928        "
 929
 930        one
 931        two
 932
 933        «<patch>
 934        <edit>
 935        <description>add a `two` function</description>
 936        <path>src/lib.rs</path>
 937        <operation>insert_after</operation>
 938        <old_text>fn zero</old_text>
 939        <new_text>
 940        fn two() {}
 941        </new_text>
 942        </edit>
 943        </patch>
 944        »
 945        also,",
 946        &[&[AssistantEdit {
 947            path: "src/lib.rs".into(),
 948            kind: AssistantEditKind::InsertAfter {
 949                old_text: "fn zero".into(),
 950                new_text: "fn two() {}".into(),
 951                description: Some("add a `two` function".into()),
 952            },
 953        }]],
 954        cx,
 955    );
 956
 957    // Ensure steps are re-parsed when deserializing.
 958    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 959    let deserialized_context = cx.new(|cx| {
 960        AssistantContext::deserialize(
 961            serialized_context,
 962            Default::default(),
 963            registry.clone(),
 964            prompt_builder.clone(),
 965            Arc::new(SlashCommandWorkingSet::default()),
 966            None,
 967            None,
 968            cx,
 969        )
 970    });
 971    expect_patches(
 972        &deserialized_context,
 973        "
 974
 975        one
 976        two
 977
 978        «<patch>
 979        <edit>
 980        <description>add a `two` function</description>
 981        <path>src/lib.rs</path>
 982        <operation>insert_after</operation>
 983        <old_text>fn zero</old_text>
 984        <new_text>
 985        fn two() {}
 986        </new_text>
 987        </edit>
 988        </patch>
 989        »
 990        also,",
 991        &[&[AssistantEdit {
 992            path: "src/lib.rs".into(),
 993            kind: AssistantEditKind::InsertAfter {
 994                old_text: "fn zero".into(),
 995                new_text: "fn two() {}".into(),
 996                description: Some("add a `two` function".into()),
 997            },
 998        }]],
 999        cx,
1000    );
1001
1002    fn edit(
1003        context: &Entity<AssistantContext>,
1004        new_text_marked_with_edits: &str,
1005        cx: &mut TestAppContext,
1006    ) {
1007        context.update(cx, |context, cx| {
1008            context.buffer.update(cx, |buffer, cx| {
1009                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1010            });
1011        });
1012        cx.executor().run_until_parked();
1013    }
1014
1015    #[track_caller]
1016    fn expect_patches(
1017        context: &Entity<AssistantContext>,
1018        expected_marked_text: &str,
1019        expected_suggestions: &[&[AssistantEdit]],
1020        cx: &mut TestAppContext,
1021    ) {
1022        let expected_marked_text = expected_marked_text.unindent();
1023        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1024
1025        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1026            context.buffer.read_with(cx, |buffer, _| {
1027                let ranges = context
1028                    .patches
1029                    .iter()
1030                    .map(|entry| entry.range.to_offset(buffer))
1031                    .collect::<Vec<_>>();
1032                (
1033                    buffer.text(),
1034                    ranges,
1035                    context
1036                        .patches
1037                        .iter()
1038                        .map(|step| step.edits.clone())
1039                        .collect::<Vec<_>>(),
1040                )
1041            })
1042        });
1043
1044        assert_eq!(buffer_text, expected_text);
1045
1046        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1047        assert_eq!(actual_marked_text, expected_marked_text);
1048
1049        assert_eq!(
1050            patches
1051                .iter()
1052                .map(|patch| {
1053                    patch
1054                        .iter()
1055                        .map(|edit| {
1056                            let edit = edit.as_ref().unwrap();
1057                            AssistantEdit {
1058                                path: edit.path.clone(),
1059                                kind: edit.kind.clone(),
1060                            }
1061                        })
1062                        .collect::<Vec<_>>()
1063                })
1064                .collect::<Vec<_>>(),
1065            expected_suggestions
1066        );
1067    }
1068}
1069
1070#[gpui::test]
1071async fn test_serialization(cx: &mut TestAppContext) {
1072    let settings_store = cx.update(SettingsStore::test);
1073    cx.set_global(settings_store);
1074    cx.update(LanguageModelRegistry::test);
1075    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1076    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1077    let context = cx.new(|cx| {
1078        AssistantContext::local(
1079            registry.clone(),
1080            None,
1081            None,
1082            prompt_builder.clone(),
1083            Arc::new(SlashCommandWorkingSet::default()),
1084            cx,
1085        )
1086    });
1087    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1088    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1089    let message_1 = context.update(cx, |context, cx| {
1090        context
1091            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1092            .unwrap()
1093    });
1094    let message_2 = context.update(cx, |context, cx| {
1095        context
1096            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1097            .unwrap()
1098    });
1099    buffer.update(cx, |buffer, cx| {
1100        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1101        buffer.finalize_last_transaction();
1102    });
1103    let _message_3 = context.update(cx, |context, cx| {
1104        context
1105            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1106            .unwrap()
1107    });
1108    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1109    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1110    assert_eq!(
1111        cx.read(|cx| messages(&context, cx)),
1112        [
1113            (message_0, Role::User, 0..2),
1114            (message_1.id, Role::Assistant, 2..6),
1115            (message_2.id, Role::System, 6..6),
1116        ]
1117    );
1118
1119    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1120    let deserialized_context = cx.new(|cx| {
1121        AssistantContext::deserialize(
1122            serialized_context,
1123            Default::default(),
1124            registry.clone(),
1125            prompt_builder.clone(),
1126            Arc::new(SlashCommandWorkingSet::default()),
1127            None,
1128            None,
1129            cx,
1130        )
1131    });
1132    let deserialized_buffer =
1133        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1134    assert_eq!(
1135        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1136        "a\nb\nc\n"
1137    );
1138    assert_eq!(
1139        cx.read(|cx| messages(&deserialized_context, cx)),
1140        [
1141            (message_0, Role::User, 0..2),
1142            (message_1.id, Role::Assistant, 2..6),
1143            (message_2.id, Role::System, 6..6),
1144        ]
1145    );
1146}
1147
1148#[gpui::test(iterations = 100)]
1149async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1150    let min_peers = env::var("MIN_PEERS")
1151        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1152        .unwrap_or(2);
1153    let max_peers = env::var("MAX_PEERS")
1154        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1155        .unwrap_or(5);
1156    let operations = env::var("OPERATIONS")
1157        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1158        .unwrap_or(50);
1159
1160    let settings_store = cx.update(SettingsStore::test);
1161    cx.set_global(settings_store);
1162    cx.update(LanguageModelRegistry::test);
1163
1164    let slash_commands = cx.update(SlashCommandRegistry::default_global);
1165    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1166    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1167    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1168
1169    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1170    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1171    let mut contexts = Vec::new();
1172
1173    let num_peers = rng.gen_range(min_peers..=max_peers);
1174    let context_id = ContextId::new();
1175    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1176    for i in 0..num_peers {
1177        let context = cx.new(|cx| {
1178            AssistantContext::new(
1179                context_id.clone(),
1180                i as ReplicaId,
1181                language::Capability::ReadWrite,
1182                registry.clone(),
1183                prompt_builder.clone(),
1184                Arc::new(SlashCommandWorkingSet::default()),
1185                None,
1186                None,
1187                cx,
1188            )
1189        });
1190
1191        cx.update(|cx| {
1192            cx.subscribe(&context, {
1193                let network = network.clone();
1194                move |_, event, _| {
1195                    if let ContextEvent::Operation(op) = event {
1196                        network
1197                            .lock()
1198                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
1199                    }
1200                }
1201            })
1202            .detach();
1203        });
1204
1205        contexts.push(context);
1206        network.lock().add_peer(i as ReplicaId);
1207    }
1208
1209    let mut mutation_count = operations;
1210
1211    while mutation_count > 0
1212        || !network.lock().is_idle()
1213        || network.lock().contains_disconnected_peers()
1214    {
1215        let context_index = rng.gen_range(0..contexts.len());
1216        let context = &contexts[context_index];
1217
1218        match rng.gen_range(0..100) {
1219            0..=29 if mutation_count > 0 => {
1220                log::info!("Context {}: edit buffer", context_index);
1221                context.update(cx, |context, cx| {
1222                    context
1223                        .buffer
1224                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1225                });
1226                mutation_count -= 1;
1227            }
1228            30..=44 if mutation_count > 0 => {
1229                context.update(cx, |context, cx| {
1230                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1231                    log::info!("Context {}: split message at {:?}", context_index, range);
1232                    context.split_message(range, cx);
1233                });
1234                mutation_count -= 1;
1235            }
1236            45..=59 if mutation_count > 0 => {
1237                context.update(cx, |context, cx| {
1238                    if let Some(message) = context.messages(cx).choose(&mut rng) {
1239                        let role = *[Role::User, Role::Assistant, Role::System]
1240                            .choose(&mut rng)
1241                            .unwrap();
1242                        log::info!(
1243                            "Context {}: insert message after {:?} with {:?}",
1244                            context_index,
1245                            message.id,
1246                            role
1247                        );
1248                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1249                    }
1250                });
1251                mutation_count -= 1;
1252            }
1253            60..=74 if mutation_count > 0 => {
1254                context.update(cx, |context, cx| {
1255                    let command_text = "/".to_string()
1256                        + slash_commands
1257                            .command_names()
1258                            .choose(&mut rng)
1259                            .unwrap()
1260                            .clone()
1261                            .as_ref();
1262
1263                    let command_range = context.buffer.update(cx, |buffer, cx| {
1264                        let offset = buffer.random_byte_range(0, &mut rng).start;
1265                        buffer.edit(
1266                            [(offset..offset, format!("\n{}\n", command_text))],
1267                            None,
1268                            cx,
1269                        );
1270                        offset + 1..offset + 1 + command_text.len()
1271                    });
1272
1273                    let output_text = RandomCharIter::new(&mut rng)
1274                        .filter(|c| *c != '\r')
1275                        .take(10)
1276                        .collect::<String>();
1277
1278                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1279                        role: Role::User,
1280                        merge_same_roles: true,
1281                    })];
1282
1283                    let num_sections = rng.gen_range(0..=3);
1284                    let mut section_start = 0;
1285                    for _ in 0..num_sections {
1286                        let mut section_end = rng.gen_range(section_start..=output_text.len());
1287                        while !output_text.is_char_boundary(section_end) {
1288                            section_end += 1;
1289                        }
1290                        events.push(Ok(SlashCommandEvent::StartSection {
1291                            icon: IconName::Ai,
1292                            label: "section".into(),
1293                            metadata: None,
1294                        }));
1295                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1296                            text: output_text[section_start..section_end].to_string(),
1297                            run_commands_in_text: false,
1298                        })));
1299                        events.push(Ok(SlashCommandEvent::EndSection));
1300                        section_start = section_end;
1301                    }
1302
1303                    if section_start < output_text.len() {
1304                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1305                            text: output_text[section_start..].to_string(),
1306                            run_commands_in_text: false,
1307                        })));
1308                    }
1309
1310                    log::info!(
1311                        "Context {}: insert slash command output at {:?} with {:?} events",
1312                        context_index,
1313                        command_range,
1314                        events.len()
1315                    );
1316
1317                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1318                        ..context.buffer.read(cx).anchor_after(command_range.end);
1319                    context.insert_command_output(
1320                        command_range,
1321                        "/command",
1322                        Task::ready(Ok(stream::iter(events).boxed())),
1323                        true,
1324                        cx,
1325                    );
1326                });
1327                cx.run_until_parked();
1328                mutation_count -= 1;
1329            }
1330            75..=84 if mutation_count > 0 => {
1331                context.update(cx, |context, cx| {
1332                    if let Some(message) = context.messages(cx).choose(&mut rng) {
1333                        let new_status = match rng.gen_range(0..3) {
1334                            0 => MessageStatus::Done,
1335                            1 => MessageStatus::Pending,
1336                            _ => MessageStatus::Error(SharedString::from("Random error")),
1337                        };
1338                        log::info!(
1339                            "Context {}: update message {:?} status to {:?}",
1340                            context_index,
1341                            message.id,
1342                            new_status
1343                        );
1344                        context.update_metadata(message.id, cx, |metadata| {
1345                            metadata.status = new_status;
1346                        });
1347                    }
1348                });
1349                mutation_count -= 1;
1350            }
1351            _ => {
1352                let replica_id = context_index as ReplicaId;
1353                if network.lock().is_disconnected(replica_id) {
1354                    network.lock().reconnect_peer(replica_id, 0);
1355
1356                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1357                        let host_context = &contexts[0].read(cx);
1358                        let guest_context = context.read(cx);
1359                        (
1360                            guest_context.serialize_ops(&host_context.version(cx), cx),
1361                            host_context.serialize_ops(&guest_context.version(cx), cx),
1362                        )
1363                    });
1364                    let ops_to_send = ops_to_send.await;
1365                    let ops_to_receive = ops_to_receive
1366                        .await
1367                        .into_iter()
1368                        .map(ContextOperation::from_proto)
1369                        .collect::<Result<Vec<_>>>()
1370                        .unwrap();
1371                    log::info!(
1372                        "Context {}: reconnecting. Sent {} operations, received {} operations",
1373                        context_index,
1374                        ops_to_send.len(),
1375                        ops_to_receive.len()
1376                    );
1377
1378                    network.lock().broadcast(replica_id, ops_to_send);
1379                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1380                } else if rng.gen_bool(0.1) && replica_id != 0 {
1381                    log::info!("Context {}: disconnecting", context_index);
1382                    network.lock().disconnect_peer(replica_id);
1383                } else if network.lock().has_unreceived(replica_id) {
1384                    log::info!("Context {}: applying operations", context_index);
1385                    let ops = network.lock().receive(replica_id);
1386                    let ops = ops
1387                        .into_iter()
1388                        .map(ContextOperation::from_proto)
1389                        .collect::<Result<Vec<_>>>()
1390                        .unwrap();
1391                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
1392                }
1393            }
1394        }
1395    }
1396
1397    cx.read(|cx| {
1398        let first_context = contexts[0].read(cx);
1399        for context in &contexts[1..] {
1400            let context = context.read(cx);
1401            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1402            assert_eq!(
1403                context.buffer.read(cx).text(),
1404                first_context.buffer.read(cx).text(),
1405                "Context {} text != Context 0 text",
1406                context.buffer.read(cx).replica_id()
1407            );
1408            assert_eq!(
1409                context.message_anchors,
1410                first_context.message_anchors,
1411                "Context {} messages != Context 0 messages",
1412                context.buffer.read(cx).replica_id()
1413            );
1414            assert_eq!(
1415                context.messages_metadata,
1416                first_context.messages_metadata,
1417                "Context {} message metadata != Context 0 message metadata",
1418                context.buffer.read(cx).replica_id()
1419            );
1420            assert_eq!(
1421                context.slash_command_output_sections,
1422                first_context.slash_command_output_sections,
1423                "Context {} slash command output sections != Context 0 slash command output sections",
1424                context.buffer.read(cx).replica_id()
1425            );
1426        }
1427    });
1428}
1429
1430#[gpui::test]
1431fn test_mark_cache_anchors(cx: &mut App) {
1432    let settings_store = SettingsStore::test(cx);
1433    LanguageModelRegistry::test(cx);
1434    cx.set_global(settings_store);
1435    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1436    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1437    let context = cx.new(|cx| {
1438        AssistantContext::local(
1439            registry,
1440            None,
1441            None,
1442            prompt_builder.clone(),
1443            Arc::new(SlashCommandWorkingSet::default()),
1444            cx,
1445        )
1446    });
1447    let buffer = context.read(cx).buffer.clone();
1448
1449    // Create a test cache configuration
1450    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1451        max_cache_anchors: 3,
1452        should_speculate: true,
1453        min_total_token: 10,
1454    });
1455
1456    let message_1 = context.read(cx).message_anchors[0].clone();
1457
1458    context.update(cx, |context, cx| {
1459        context.mark_cache_anchors(cache_configuration, false, cx)
1460    });
1461
1462    assert_eq!(
1463        messages_cache(&context, cx)
1464            .iter()
1465            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1466            .count(),
1467        0,
1468        "Empty messages should not have any cache anchors."
1469    );
1470
1471    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1472    let message_2 = context
1473        .update(cx, |context, cx| {
1474            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1475        })
1476        .unwrap();
1477
1478    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1479    let message_3 = context
1480        .update(cx, |context, cx| {
1481            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1482        })
1483        .unwrap();
1484    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1485
1486    context.update(cx, |context, cx| {
1487        context.mark_cache_anchors(cache_configuration, false, cx)
1488    });
1489    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1490    assert_eq!(
1491        messages_cache(&context, cx)
1492            .iter()
1493            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1494            .count(),
1495        0,
1496        "Messages should not be marked for cache before going over the token minimum."
1497    );
1498    context.update(cx, |context, _| {
1499        context.token_count = Some(20);
1500    });
1501
1502    context.update(cx, |context, cx| {
1503        context.mark_cache_anchors(cache_configuration, true, cx)
1504    });
1505    assert_eq!(
1506        messages_cache(&context, cx)
1507            .iter()
1508            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1509            .collect::<Vec<bool>>(),
1510        vec![true, true, false],
1511        "Last message should not be an anchor on speculative request."
1512    );
1513
1514    context
1515        .update(cx, |context, cx| {
1516            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1517        })
1518        .unwrap();
1519
1520    context.update(cx, |context, cx| {
1521        context.mark_cache_anchors(cache_configuration, false, cx)
1522    });
1523    assert_eq!(
1524        messages_cache(&context, cx)
1525            .iter()
1526            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1527            .collect::<Vec<bool>>(),
1528        vec![false, true, true, false],
1529        "Most recent message should also be cached if not a speculative request."
1530    );
1531    context.update(cx, |context, cx| {
1532        context.update_cache_status_for_completion(cx)
1533    });
1534    assert_eq!(
1535        messages_cache(&context, cx)
1536            .iter()
1537            .map(|(_, cache)| cache
1538                .as_ref()
1539                .map_or(None, |cache| Some(cache.status.clone())))
1540            .collect::<Vec<Option<CacheStatus>>>(),
1541        vec![
1542            Some(CacheStatus::Cached),
1543            Some(CacheStatus::Cached),
1544            Some(CacheStatus::Cached),
1545            None
1546        ],
1547        "All user messages prior to anchor should be marked as cached."
1548    );
1549
1550    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1551    context.update(cx, |context, cx| {
1552        context.mark_cache_anchors(cache_configuration, false, cx)
1553    });
1554    assert_eq!(
1555        messages_cache(&context, cx)
1556            .iter()
1557            .map(|(_, cache)| cache
1558                .as_ref()
1559                .map_or(None, |cache| Some(cache.status.clone())))
1560            .collect::<Vec<Option<CacheStatus>>>(),
1561        vec![
1562            Some(CacheStatus::Cached),
1563            Some(CacheStatus::Cached),
1564            Some(CacheStatus::Pending),
1565            None
1566        ],
1567        "Modifying a message should invalidate it's cache but leave previous messages."
1568    );
1569    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1570    context.update(cx, |context, cx| {
1571        context.mark_cache_anchors(cache_configuration, false, cx)
1572    });
1573    assert_eq!(
1574        messages_cache(&context, cx)
1575            .iter()
1576            .map(|(_, cache)| cache
1577                .as_ref()
1578                .map_or(None, |cache| Some(cache.status.clone())))
1579            .collect::<Vec<Option<CacheStatus>>>(),
1580        vec![
1581            Some(CacheStatus::Pending),
1582            Some(CacheStatus::Pending),
1583            Some(CacheStatus::Pending),
1584            None
1585        ],
1586        "Modifying a message should invalidate all future messages."
1587    );
1588}
1589
1590fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1591    context
1592        .read(cx)
1593        .messages(cx)
1594        .map(|message| (message.id, message.role, message.offset_range))
1595        .collect()
1596}
1597
1598fn messages_cache(
1599    context: &Entity<AssistantContext>,
1600    cx: &App,
1601) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1602    context
1603        .read(cx)
1604        .messages(cx)
1605        .map(|message| (message.id, message.cache.clone()))
1606        .collect()
1607}
1608
1609#[derive(Clone)]
1610struct FakeSlashCommand(String);
1611
1612impl SlashCommand for FakeSlashCommand {
1613    fn name(&self) -> String {
1614        self.0.clone()
1615    }
1616
1617    fn description(&self) -> String {
1618        format!("Fake slash command: {}", self.0)
1619    }
1620
1621    fn menu_text(&self) -> String {
1622        format!("Run fake command: {}", self.0)
1623    }
1624
1625    fn complete_argument(
1626        self: Arc<Self>,
1627        _arguments: &[String],
1628        _cancel: Arc<AtomicBool>,
1629        _workspace: Option<WeakEntity<Workspace>>,
1630        _window: &mut Window,
1631        _cx: &mut App,
1632    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1633        Task::ready(Ok(vec![]))
1634    }
1635
1636    fn requires_argument(&self) -> bool {
1637        false
1638    }
1639
1640    fn run(
1641        self: Arc<Self>,
1642        _arguments: &[String],
1643        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1644        _context_buffer: BufferSnapshot,
1645        _workspace: WeakEntity<Workspace>,
1646        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1647        _window: &mut Window,
1648        _cx: &mut App,
1649    ) -> Task<SlashCommandResult> {
1650        Task::ready(Ok(SlashCommandOutput {
1651            text: format!("Executed fake command: {}", self.0),
1652            sections: vec![],
1653            run_commands_in_text: false,
1654        }
1655        .to_event_stream()))
1656    }
1657}