context_tests.rs

   1use crate::{
   2    AssistantEdit, AssistantEditKind, CacheStatus, Context, ContextEvent, ContextId,
   3    ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use assistant_tool::ToolWorkingSet;
  12use collections::{HashMap, HashSet};
  13use fs::FakeFs;
  14use futures::{
  15    channel::mpsc,
  16    stream::{self, StreamExt},
  17};
  18use gpui::{prelude::*, AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  19use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  20use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  21use parking_lot::Mutex;
  22use pretty_assertions::assert_eq;
  23use project::Project;
  24use prompt_library::PromptBuilder;
  25use rand::prelude::*;
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{
  29    cell::RefCell,
  30    env,
  31    ops::Range,
  32    path::Path,
  33    rc::Rc,
  34    sync::{atomic::AtomicBool, Arc},
  35};
  36use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToOffset};
  37use ui::{IconName, WindowContext};
  38use unindent::Unindent;
  39use util::{
  40    test::{generate_marked_text, marked_text_ranges},
  41    RandomCharIter,
  42};
  43use workspace::Workspace;
  44
  45#[gpui::test]
  46fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  47    let settings_store = SettingsStore::test(cx);
  48    LanguageModelRegistry::test(cx);
  49    cx.set_global(settings_store);
  50    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  51    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  52    let context = cx.new_model(|cx| {
  53        Context::local(
  54            registry,
  55            None,
  56            None,
  57            prompt_builder.clone(),
  58            Arc::new(SlashCommandWorkingSet::default()),
  59            Arc::new(ToolWorkingSet::default()),
  60            cx,
  61        )
  62    });
  63    let buffer = context.read(cx).buffer.clone();
  64
  65    let message_1 = context.read(cx).message_anchors[0].clone();
  66    assert_eq!(
  67        messages(&context, cx),
  68        vec![(message_1.id, Role::User, 0..0)]
  69    );
  70
  71    let message_2 = context.update(cx, |context, cx| {
  72        context
  73            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  74            .unwrap()
  75    });
  76    assert_eq!(
  77        messages(&context, cx),
  78        vec![
  79            (message_1.id, Role::User, 0..1),
  80            (message_2.id, Role::Assistant, 1..1)
  81        ]
  82    );
  83
  84    buffer.update(cx, |buffer, cx| {
  85        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  86    });
  87    assert_eq!(
  88        messages(&context, cx),
  89        vec![
  90            (message_1.id, Role::User, 0..2),
  91            (message_2.id, Role::Assistant, 2..3)
  92        ]
  93    );
  94
  95    let message_3 = context.update(cx, |context, cx| {
  96        context
  97            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  98            .unwrap()
  99    });
 100    assert_eq!(
 101        messages(&context, cx),
 102        vec![
 103            (message_1.id, Role::User, 0..2),
 104            (message_2.id, Role::Assistant, 2..4),
 105            (message_3.id, Role::User, 4..4)
 106        ]
 107    );
 108
 109    let message_4 = context.update(cx, |context, cx| {
 110        context
 111            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 112            .unwrap()
 113    });
 114    assert_eq!(
 115        messages(&context, cx),
 116        vec![
 117            (message_1.id, Role::User, 0..2),
 118            (message_2.id, Role::Assistant, 2..4),
 119            (message_4.id, Role::User, 4..5),
 120            (message_3.id, Role::User, 5..5),
 121        ]
 122    );
 123
 124    buffer.update(cx, |buffer, cx| {
 125        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 126    });
 127    assert_eq!(
 128        messages(&context, cx),
 129        vec![
 130            (message_1.id, Role::User, 0..2),
 131            (message_2.id, Role::Assistant, 2..4),
 132            (message_4.id, Role::User, 4..6),
 133            (message_3.id, Role::User, 6..7),
 134        ]
 135    );
 136
 137    // Deleting across message boundaries merges the messages.
 138    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 139    assert_eq!(
 140        messages(&context, cx),
 141        vec![
 142            (message_1.id, Role::User, 0..3),
 143            (message_3.id, Role::User, 3..4),
 144        ]
 145    );
 146
 147    // Undoing the deletion should also undo the merge.
 148    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 149    assert_eq!(
 150        messages(&context, cx),
 151        vec![
 152            (message_1.id, Role::User, 0..2),
 153            (message_2.id, Role::Assistant, 2..4),
 154            (message_4.id, Role::User, 4..6),
 155            (message_3.id, Role::User, 6..7),
 156        ]
 157    );
 158
 159    // Redoing the deletion should also redo the merge.
 160    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 161    assert_eq!(
 162        messages(&context, cx),
 163        vec![
 164            (message_1.id, Role::User, 0..3),
 165            (message_3.id, Role::User, 3..4),
 166        ]
 167    );
 168
 169    // Ensure we can still insert after a merged message.
 170    let message_5 = context.update(cx, |context, cx| {
 171        context
 172            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 173            .unwrap()
 174    });
 175    assert_eq!(
 176        messages(&context, cx),
 177        vec![
 178            (message_1.id, Role::User, 0..3),
 179            (message_5.id, Role::System, 3..4),
 180            (message_3.id, Role::User, 4..5)
 181        ]
 182    );
 183}
 184
 185#[gpui::test]
 186fn test_message_splitting(cx: &mut AppContext) {
 187    let settings_store = SettingsStore::test(cx);
 188    cx.set_global(settings_store);
 189    LanguageModelRegistry::test(cx);
 190    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 191
 192    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 193    let context = cx.new_model(|cx| {
 194        Context::local(
 195            registry.clone(),
 196            None,
 197            None,
 198            prompt_builder.clone(),
 199            Arc::new(SlashCommandWorkingSet::default()),
 200            Arc::new(ToolWorkingSet::default()),
 201            cx,
 202        )
 203    });
 204    let buffer = context.read(cx).buffer.clone();
 205
 206    let message_1 = context.read(cx).message_anchors[0].clone();
 207    assert_eq!(
 208        messages(&context, cx),
 209        vec![(message_1.id, Role::User, 0..0)]
 210    );
 211
 212    buffer.update(cx, |buffer, cx| {
 213        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 214    });
 215
 216    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 217    let message_2 = message_2.unwrap();
 218
 219    // We recycle newlines in the middle of a split message
 220    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 221    assert_eq!(
 222        messages(&context, cx),
 223        vec![
 224            (message_1.id, Role::User, 0..4),
 225            (message_2.id, Role::User, 4..16),
 226        ]
 227    );
 228
 229    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 230    let message_3 = message_3.unwrap();
 231
 232    // We don't recycle newlines at the end of a split message
 233    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 234    assert_eq!(
 235        messages(&context, cx),
 236        vec![
 237            (message_1.id, Role::User, 0..4),
 238            (message_3.id, Role::User, 4..5),
 239            (message_2.id, Role::User, 5..17),
 240        ]
 241    );
 242
 243    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 244    let message_4 = message_4.unwrap();
 245    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 246    assert_eq!(
 247        messages(&context, cx),
 248        vec![
 249            (message_1.id, Role::User, 0..4),
 250            (message_3.id, Role::User, 4..5),
 251            (message_2.id, Role::User, 5..9),
 252            (message_4.id, Role::User, 9..17),
 253        ]
 254    );
 255
 256    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 257    let message_5 = message_5.unwrap();
 258    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 259    assert_eq!(
 260        messages(&context, cx),
 261        vec![
 262            (message_1.id, Role::User, 0..4),
 263            (message_3.id, Role::User, 4..5),
 264            (message_2.id, Role::User, 5..9),
 265            (message_4.id, Role::User, 9..10),
 266            (message_5.id, Role::User, 10..18),
 267        ]
 268    );
 269
 270    let (message_6, message_7) =
 271        context.update(cx, |context, cx| context.split_message(14..16, cx));
 272    let message_6 = message_6.unwrap();
 273    let message_7 = message_7.unwrap();
 274    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 275    assert_eq!(
 276        messages(&context, cx),
 277        vec![
 278            (message_1.id, Role::User, 0..4),
 279            (message_3.id, Role::User, 4..5),
 280            (message_2.id, Role::User, 5..9),
 281            (message_4.id, Role::User, 9..10),
 282            (message_5.id, Role::User, 10..14),
 283            (message_6.id, Role::User, 14..17),
 284            (message_7.id, Role::User, 17..19),
 285        ]
 286    );
 287}
 288
 289#[gpui::test]
 290fn test_messages_for_offsets(cx: &mut AppContext) {
 291    let settings_store = SettingsStore::test(cx);
 292    LanguageModelRegistry::test(cx);
 293    cx.set_global(settings_store);
 294    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 295    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 296    let context = cx.new_model(|cx| {
 297        Context::local(
 298            registry,
 299            None,
 300            None,
 301            prompt_builder.clone(),
 302            Arc::new(SlashCommandWorkingSet::default()),
 303            Arc::new(ToolWorkingSet::default()),
 304            cx,
 305        )
 306    });
 307    let buffer = context.read(cx).buffer.clone();
 308
 309    let message_1 = context.read(cx).message_anchors[0].clone();
 310    assert_eq!(
 311        messages(&context, cx),
 312        vec![(message_1.id, Role::User, 0..0)]
 313    );
 314
 315    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 316    let message_2 = context
 317        .update(cx, |context, cx| {
 318            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 319        })
 320        .unwrap();
 321    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 322
 323    let message_3 = context
 324        .update(cx, |context, cx| {
 325            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 326        })
 327        .unwrap();
 328    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 329
 330    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 331    assert_eq!(
 332        messages(&context, cx),
 333        vec![
 334            (message_1.id, Role::User, 0..4),
 335            (message_2.id, Role::User, 4..8),
 336            (message_3.id, Role::User, 8..11)
 337        ]
 338    );
 339
 340    assert_eq!(
 341        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 342        [message_1.id, message_2.id, message_3.id]
 343    );
 344    assert_eq!(
 345        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 346        [message_1.id, message_3.id]
 347    );
 348
 349    let message_4 = context
 350        .update(cx, |context, cx| {
 351            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 352        })
 353        .unwrap();
 354    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 355    assert_eq!(
 356        messages(&context, cx),
 357        vec![
 358            (message_1.id, Role::User, 0..4),
 359            (message_2.id, Role::User, 4..8),
 360            (message_3.id, Role::User, 8..12),
 361            (message_4.id, Role::User, 12..12)
 362        ]
 363    );
 364    assert_eq!(
 365        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 366        [message_1.id, message_2.id, message_3.id, message_4.id]
 367    );
 368
 369    fn message_ids_for_offsets(
 370        context: &Model<Context>,
 371        offsets: &[usize],
 372        cx: &AppContext,
 373    ) -> Vec<MessageId> {
 374        context
 375            .read(cx)
 376            .messages_for_offsets(offsets.iter().copied(), cx)
 377            .into_iter()
 378            .map(|message| message.id)
 379            .collect()
 380    }
 381}
 382
 383#[gpui::test]
 384async fn test_slash_commands(cx: &mut TestAppContext) {
 385    let settings_store = cx.update(SettingsStore::test);
 386    cx.set_global(settings_store);
 387    cx.update(LanguageModelRegistry::test);
 388    cx.update(Project::init_settings);
 389    let fs = FakeFs::new(cx.background_executor.clone());
 390
 391    fs.insert_tree(
 392        "/test",
 393        json!({
 394            "src": {
 395                "lib.rs": "fn one() -> usize { 1 }",
 396                "main.rs": "
 397                    use crate::one;
 398                    fn main() { one(); }
 399                ".unindent(),
 400            }
 401        }),
 402    )
 403    .await;
 404
 405    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 406    slash_command_registry.register_command(FileSlashCommand, false);
 407
 408    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 409    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 410    let context = cx.new_model(|cx| {
 411        Context::local(
 412            registry.clone(),
 413            None,
 414            None,
 415            prompt_builder.clone(),
 416            Arc::new(SlashCommandWorkingSet::default()),
 417            Arc::new(ToolWorkingSet::default()),
 418            cx,
 419        )
 420    });
 421
 422    #[derive(Default)]
 423    struct ContextRanges {
 424        parsed_commands: HashSet<Range<language::Anchor>>,
 425        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 426        output_sections: HashSet<Range<language::Anchor>>,
 427    }
 428
 429    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 430    context.update(cx, |_, cx| {
 431        cx.subscribe(&context, {
 432            let context_ranges = context_ranges.clone();
 433            move |context, _, event, _| {
 434                let mut context_ranges = context_ranges.borrow_mut();
 435                match event {
 436                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 437                        let command = context.invoked_slash_command(command_id).unwrap();
 438                        context_ranges
 439                            .command_outputs
 440                            .insert(*command_id, command.range.clone());
 441                    }
 442                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 443                        for range in removed {
 444                            context_ranges.parsed_commands.remove(range);
 445                        }
 446                        for command in updated {
 447                            context_ranges
 448                                .parsed_commands
 449                                .insert(command.source_range.clone());
 450                        }
 451                    }
 452                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 453                        context_ranges.output_sections.insert(section.range.clone());
 454                    }
 455                    _ => {}
 456                }
 457            }
 458        })
 459        .detach();
 460    });
 461
 462    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 463
 464    // Insert a slash command
 465    buffer.update(cx, |buffer, cx| {
 466        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 467    });
 468    assert_text_and_context_ranges(
 469        &buffer,
 470        &context_ranges,
 471        &"
 472        «/file src/lib.rs»"
 473            .unindent(),
 474        cx,
 475    );
 476
 477    // Edit the argument of the slash command.
 478    buffer.update(cx, |buffer, cx| {
 479        let edit_offset = buffer.text().find("lib.rs").unwrap();
 480        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 481    });
 482    assert_text_and_context_ranges(
 483        &buffer,
 484        &context_ranges,
 485        &"
 486        «/file src/main.rs»"
 487            .unindent(),
 488        cx,
 489    );
 490
 491    // Edit the name of the slash command, using one that doesn't exist.
 492    buffer.update(cx, |buffer, cx| {
 493        let edit_offset = buffer.text().find("/file").unwrap();
 494        buffer.edit(
 495            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 496            None,
 497            cx,
 498        );
 499    });
 500    assert_text_and_context_ranges(
 501        &buffer,
 502        &context_ranges,
 503        &"
 504        /unknown src/main.rs"
 505            .unindent(),
 506        cx,
 507    );
 508
 509    // Undoing the insertion of an non-existent slash command resorts the previous one.
 510    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 511    assert_text_and_context_ranges(
 512        &buffer,
 513        &context_ranges,
 514        &"
 515        «/file src/main.rs»"
 516            .unindent(),
 517        cx,
 518    );
 519
 520    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 521    context.update(cx, |context, cx| {
 522        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 523        context.insert_command_output(
 524            command_source_range,
 525            "file",
 526            Task::ready(Ok(command_output_rx.boxed())),
 527            true,
 528            cx,
 529        );
 530    });
 531    assert_text_and_context_ranges(
 532        &buffer,
 533        &context_ranges,
 534        &"
 535        ⟦«/file src/main.rs»
 536        …⟧
 537        "
 538        .unindent(),
 539        cx,
 540    );
 541
 542    command_output_tx
 543        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 544            icon: IconName::Ai,
 545            label: "src/main.rs".into(),
 546            metadata: None,
 547        }))
 548        .unwrap();
 549    command_output_tx
 550        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 551        .unwrap();
 552    cx.run_until_parked();
 553    assert_text_and_context_ranges(
 554        &buffer,
 555        &context_ranges,
 556        &"
 557        ⟦«/file src/main.rs»
 558        src/main.rs…⟧
 559        "
 560        .unindent(),
 561        cx,
 562    );
 563
 564    command_output_tx
 565        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 566        .unwrap();
 567    cx.run_until_parked();
 568    assert_text_and_context_ranges(
 569        &buffer,
 570        &context_ranges,
 571        &"
 572        ⟦«/file src/main.rs»
 573        src/main.rs
 574        fn main() {}…⟧
 575        "
 576        .unindent(),
 577        cx,
 578    );
 579
 580    command_output_tx
 581        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 582        .unwrap();
 583    cx.run_until_parked();
 584    assert_text_and_context_ranges(
 585        &buffer,
 586        &context_ranges,
 587        &"
 588        ⟦«/file src/main.rs»
 589        ⟪src/main.rs
 590        fn main() {}⟫…⟧
 591        "
 592        .unindent(),
 593        cx,
 594    );
 595
 596    drop(command_output_tx);
 597    cx.run_until_parked();
 598    assert_text_and_context_ranges(
 599        &buffer,
 600        &context_ranges,
 601        &"
 602        ⟦⟪src/main.rs
 603        fn main() {}⟫⟧
 604        "
 605        .unindent(),
 606        cx,
 607    );
 608
 609    #[track_caller]
 610    fn assert_text_and_context_ranges(
 611        buffer: &Model<Buffer>,
 612        ranges: &RefCell<ContextRanges>,
 613        expected_marked_text: &str,
 614        cx: &mut TestAppContext,
 615    ) {
 616        let mut actual_marked_text = String::new();
 617        buffer.update(cx, |buffer, _| {
 618            struct Endpoint {
 619                offset: usize,
 620                marker: char,
 621            }
 622
 623            let ranges = ranges.borrow();
 624            let mut endpoints = Vec::new();
 625            for range in ranges.command_outputs.values() {
 626                endpoints.push(Endpoint {
 627                    offset: range.start.to_offset(buffer),
 628                    marker: '',
 629                });
 630            }
 631            for range in ranges.parsed_commands.iter() {
 632                endpoints.push(Endpoint {
 633                    offset: range.start.to_offset(buffer),
 634                    marker: '«',
 635                });
 636            }
 637            for range in ranges.output_sections.iter() {
 638                endpoints.push(Endpoint {
 639                    offset: range.start.to_offset(buffer),
 640                    marker: '',
 641                });
 642            }
 643
 644            for range in ranges.output_sections.iter() {
 645                endpoints.push(Endpoint {
 646                    offset: range.end.to_offset(buffer),
 647                    marker: '',
 648                });
 649            }
 650            for range in ranges.parsed_commands.iter() {
 651                endpoints.push(Endpoint {
 652                    offset: range.end.to_offset(buffer),
 653                    marker: '»',
 654                });
 655            }
 656            for range in ranges.command_outputs.values() {
 657                endpoints.push(Endpoint {
 658                    offset: range.end.to_offset(buffer),
 659                    marker: '',
 660                });
 661            }
 662
 663            endpoints.sort_by_key(|endpoint| endpoint.offset);
 664            let mut offset = 0;
 665            for endpoint in endpoints {
 666                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 667                actual_marked_text.push(endpoint.marker);
 668                offset = endpoint.offset;
 669            }
 670            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 671        });
 672
 673        assert_eq!(actual_marked_text, expected_marked_text);
 674    }
 675}
 676
 677#[gpui::test]
 678async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 679    cx.update(prompt_library::init);
 680    let mut settings_store = cx.update(SettingsStore::test);
 681    cx.update(|cx| {
 682        settings_store
 683            .set_user_settings(
 684                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 685                cx,
 686            )
 687            .unwrap()
 688    });
 689    cx.set_global(settings_store);
 690    cx.update(language::init);
 691    cx.update(Project::init_settings);
 692    let fs = FakeFs::new(cx.executor());
 693    let project = Project::test(fs, [Path::new("/root")], cx).await;
 694    cx.update(LanguageModelRegistry::test);
 695
 696    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 697
 698    // Create a new context
 699    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 700    let context = cx.new_model(|cx| {
 701        Context::local(
 702            registry.clone(),
 703            Some(project),
 704            None,
 705            prompt_builder.clone(),
 706            Arc::new(SlashCommandWorkingSet::default()),
 707            Arc::new(ToolWorkingSet::default()),
 708            cx,
 709        )
 710    });
 711
 712    // Insert an assistant message to simulate a response.
 713    let assistant_message_id = context.update(cx, |context, cx| {
 714        let user_message_id = context.messages(cx).next().unwrap().id;
 715        context
 716            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 717            .unwrap()
 718            .id
 719    });
 720
 721    // No edit tags
 722    edit(
 723        &context,
 724        "
 725
 726        «one
 727        two
 728        »",
 729        cx,
 730    );
 731    expect_patches(
 732        &context,
 733        "
 734
 735        one
 736        two
 737        ",
 738        &[],
 739        cx,
 740    );
 741
 742    // Partial edit step tag is added
 743    edit(
 744        &context,
 745        "
 746
 747        one
 748        two
 749        «
 750        <patch»",
 751        cx,
 752    );
 753    expect_patches(
 754        &context,
 755        "
 756
 757        one
 758        two
 759
 760        <patch",
 761        &[],
 762        cx,
 763    );
 764
 765    // The rest of the step tag is added. The unclosed
 766    // step is treated as incomplete.
 767    edit(
 768        &context,
 769        "
 770
 771        one
 772        two
 773
 774        <patch«>
 775        <edit>»",
 776        cx,
 777    );
 778    expect_patches(
 779        &context,
 780        "
 781
 782        one
 783        two
 784
 785        «<patch>
 786        <edit>»",
 787        &[&[]],
 788        cx,
 789    );
 790
 791    // The full patch is added
 792    edit(
 793        &context,
 794        "
 795
 796        one
 797        two
 798
 799        <patch>
 800        <edit>«
 801        <description>add a `two` function</description>
 802        <path>src/lib.rs</path>
 803        <operation>insert_after</operation>
 804        <old_text>fn one</old_text>
 805        <new_text>
 806        fn two() {}
 807        </new_text>
 808        </edit>
 809        </patch>
 810
 811        also,»",
 812        cx,
 813    );
 814    expect_patches(
 815        &context,
 816        "
 817
 818        one
 819        two
 820
 821        «<patch>
 822        <edit>
 823        <description>add a `two` function</description>
 824        <path>src/lib.rs</path>
 825        <operation>insert_after</operation>
 826        <old_text>fn one</old_text>
 827        <new_text>
 828        fn two() {}
 829        </new_text>
 830        </edit>
 831        </patch>
 832        »
 833        also,",
 834        &[&[AssistantEdit {
 835            path: "src/lib.rs".into(),
 836            kind: AssistantEditKind::InsertAfter {
 837                old_text: "fn one".into(),
 838                new_text: "fn two() {}".into(),
 839                description: Some("add a `two` function".into()),
 840            },
 841        }]],
 842        cx,
 843    );
 844
 845    // The step is manually edited.
 846    edit(
 847        &context,
 848        "
 849
 850        one
 851        two
 852
 853        <patch>
 854        <edit>
 855        <description>add a `two` function</description>
 856        <path>src/lib.rs</path>
 857        <operation>insert_after</operation>
 858        <old_text>«fn zero»</old_text>
 859        <new_text>
 860        fn two() {}
 861        </new_text>
 862        </edit>
 863        </patch>
 864
 865        also,",
 866        cx,
 867    );
 868    expect_patches(
 869        &context,
 870        "
 871
 872        one
 873        two
 874
 875        «<patch>
 876        <edit>
 877        <description>add a `two` function</description>
 878        <path>src/lib.rs</path>
 879        <operation>insert_after</operation>
 880        <old_text>fn zero</old_text>
 881        <new_text>
 882        fn two() {}
 883        </new_text>
 884        </edit>
 885        </patch>
 886        »
 887        also,",
 888        &[&[AssistantEdit {
 889            path: "src/lib.rs".into(),
 890            kind: AssistantEditKind::InsertAfter {
 891                old_text: "fn zero".into(),
 892                new_text: "fn two() {}".into(),
 893                description: Some("add a `two` function".into()),
 894            },
 895        }]],
 896        cx,
 897    );
 898
 899    // When setting the message role to User, the steps are cleared.
 900    context.update(cx, |context, cx| {
 901        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 902        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 903    });
 904    expect_patches(
 905        &context,
 906        "
 907
 908        one
 909        two
 910
 911        <patch>
 912        <edit>
 913        <description>add a `two` function</description>
 914        <path>src/lib.rs</path>
 915        <operation>insert_after</operation>
 916        <old_text>fn zero</old_text>
 917        <new_text>
 918        fn two() {}
 919        </new_text>
 920        </edit>
 921        </patch>
 922
 923        also,",
 924        &[],
 925        cx,
 926    );
 927
 928    // When setting the message role back to Assistant, the steps are reparsed.
 929    context.update(cx, |context, cx| {
 930        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 931    });
 932    expect_patches(
 933        &context,
 934        "
 935
 936        one
 937        two
 938
 939        «<patch>
 940        <edit>
 941        <description>add a `two` function</description>
 942        <path>src/lib.rs</path>
 943        <operation>insert_after</operation>
 944        <old_text>fn zero</old_text>
 945        <new_text>
 946        fn two() {}
 947        </new_text>
 948        </edit>
 949        </patch>
 950        »
 951        also,",
 952        &[&[AssistantEdit {
 953            path: "src/lib.rs".into(),
 954            kind: AssistantEditKind::InsertAfter {
 955                old_text: "fn zero".into(),
 956                new_text: "fn two() {}".into(),
 957                description: Some("add a `two` function".into()),
 958            },
 959        }]],
 960        cx,
 961    );
 962
 963    // Ensure steps are re-parsed when deserializing.
 964    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 965    let deserialized_context = cx.new_model(|cx| {
 966        Context::deserialize(
 967            serialized_context,
 968            Default::default(),
 969            registry.clone(),
 970            prompt_builder.clone(),
 971            Arc::new(SlashCommandWorkingSet::default()),
 972            Arc::new(ToolWorkingSet::default()),
 973            None,
 974            None,
 975            cx,
 976        )
 977    });
 978    expect_patches(
 979        &deserialized_context,
 980        "
 981
 982        one
 983        two
 984
 985        «<patch>
 986        <edit>
 987        <description>add a `two` function</description>
 988        <path>src/lib.rs</path>
 989        <operation>insert_after</operation>
 990        <old_text>fn zero</old_text>
 991        <new_text>
 992        fn two() {}
 993        </new_text>
 994        </edit>
 995        </patch>
 996        »
 997        also,",
 998        &[&[AssistantEdit {
 999            path: "src/lib.rs".into(),
1000            kind: AssistantEditKind::InsertAfter {
1001                old_text: "fn zero".into(),
1002                new_text: "fn two() {}".into(),
1003                description: Some("add a `two` function".into()),
1004            },
1005        }]],
1006        cx,
1007    );
1008
1009    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
1010        context.update(cx, |context, cx| {
1011            context.buffer.update(cx, |buffer, cx| {
1012                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1013            });
1014        });
1015        cx.executor().run_until_parked();
1016    }
1017
1018    #[track_caller]
1019    fn expect_patches(
1020        context: &Model<Context>,
1021        expected_marked_text: &str,
1022        expected_suggestions: &[&[AssistantEdit]],
1023        cx: &mut TestAppContext,
1024    ) {
1025        let expected_marked_text = expected_marked_text.unindent();
1026        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1027
1028        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1029            context.buffer.read_with(cx, |buffer, _| {
1030                let ranges = context
1031                    .patches
1032                    .iter()
1033                    .map(|entry| entry.range.to_offset(buffer))
1034                    .collect::<Vec<_>>();
1035                (
1036                    buffer.text(),
1037                    ranges,
1038                    context
1039                        .patches
1040                        .iter()
1041                        .map(|step| step.edits.clone())
1042                        .collect::<Vec<_>>(),
1043                )
1044            })
1045        });
1046
1047        assert_eq!(buffer_text, expected_text);
1048
1049        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1050        assert_eq!(actual_marked_text, expected_marked_text);
1051
1052        assert_eq!(
1053            patches
1054                .iter()
1055                .map(|patch| {
1056                    patch
1057                        .iter()
1058                        .map(|edit| {
1059                            let edit = edit.as_ref().unwrap();
1060                            AssistantEdit {
1061                                path: edit.path.clone(),
1062                                kind: edit.kind.clone(),
1063                            }
1064                        })
1065                        .collect::<Vec<_>>()
1066                })
1067                .collect::<Vec<_>>(),
1068            expected_suggestions
1069        );
1070    }
1071}
1072
1073#[gpui::test]
1074async fn test_serialization(cx: &mut TestAppContext) {
1075    let settings_store = cx.update(SettingsStore::test);
1076    cx.set_global(settings_store);
1077    cx.update(LanguageModelRegistry::test);
1078    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1079    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1080    let context = cx.new_model(|cx| {
1081        Context::local(
1082            registry.clone(),
1083            None,
1084            None,
1085            prompt_builder.clone(),
1086            Arc::new(SlashCommandWorkingSet::default()),
1087            Arc::new(ToolWorkingSet::default()),
1088            cx,
1089        )
1090    });
1091    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1092    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1093    let message_1 = context.update(cx, |context, cx| {
1094        context
1095            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1096            .unwrap()
1097    });
1098    let message_2 = context.update(cx, |context, cx| {
1099        context
1100            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1101            .unwrap()
1102    });
1103    buffer.update(cx, |buffer, cx| {
1104        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1105        buffer.finalize_last_transaction();
1106    });
1107    let _message_3 = context.update(cx, |context, cx| {
1108        context
1109            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1110            .unwrap()
1111    });
1112    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1113    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1114    assert_eq!(
1115        cx.read(|cx| messages(&context, cx)),
1116        [
1117            (message_0, Role::User, 0..2),
1118            (message_1.id, Role::Assistant, 2..6),
1119            (message_2.id, Role::System, 6..6),
1120        ]
1121    );
1122
1123    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1124    let deserialized_context = cx.new_model(|cx| {
1125        Context::deserialize(
1126            serialized_context,
1127            Default::default(),
1128            registry.clone(),
1129            prompt_builder.clone(),
1130            Arc::new(SlashCommandWorkingSet::default()),
1131            Arc::new(ToolWorkingSet::default()),
1132            None,
1133            None,
1134            cx,
1135        )
1136    });
1137    let deserialized_buffer =
1138        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1139    assert_eq!(
1140        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1141        "a\nb\nc\n"
1142    );
1143    assert_eq!(
1144        cx.read(|cx| messages(&deserialized_context, cx)),
1145        [
1146            (message_0, Role::User, 0..2),
1147            (message_1.id, Role::Assistant, 2..6),
1148            (message_2.id, Role::System, 6..6),
1149        ]
1150    );
1151}
1152
1153#[gpui::test(iterations = 100)]
1154async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1155    let min_peers = env::var("MIN_PEERS")
1156        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1157        .unwrap_or(2);
1158    let max_peers = env::var("MAX_PEERS")
1159        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1160        .unwrap_or(5);
1161    let operations = env::var("OPERATIONS")
1162        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1163        .unwrap_or(50);
1164
1165    let settings_store = cx.update(SettingsStore::test);
1166    cx.set_global(settings_store);
1167    cx.update(LanguageModelRegistry::test);
1168
1169    let slash_commands = cx.update(SlashCommandRegistry::default_global);
1170    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1171    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1172    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1173
1174    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1175    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1176    let mut contexts = Vec::new();
1177
1178    let num_peers = rng.gen_range(min_peers..=max_peers);
1179    let context_id = ContextId::new();
1180    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1181    for i in 0..num_peers {
1182        let context = cx.new_model(|cx| {
1183            Context::new(
1184                context_id.clone(),
1185                i as ReplicaId,
1186                language::Capability::ReadWrite,
1187                registry.clone(),
1188                prompt_builder.clone(),
1189                Arc::new(SlashCommandWorkingSet::default()),
1190                Arc::new(ToolWorkingSet::default()),
1191                None,
1192                None,
1193                cx,
1194            )
1195        });
1196
1197        cx.update(|cx| {
1198            cx.subscribe(&context, {
1199                let network = network.clone();
1200                move |_, event, _| {
1201                    if let ContextEvent::Operation(op) = event {
1202                        network
1203                            .lock()
1204                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
1205                    }
1206                }
1207            })
1208            .detach();
1209        });
1210
1211        contexts.push(context);
1212        network.lock().add_peer(i as ReplicaId);
1213    }
1214
1215    let mut mutation_count = operations;
1216
1217    while mutation_count > 0
1218        || !network.lock().is_idle()
1219        || network.lock().contains_disconnected_peers()
1220    {
1221        let context_index = rng.gen_range(0..contexts.len());
1222        let context = &contexts[context_index];
1223
1224        match rng.gen_range(0..100) {
1225            0..=29 if mutation_count > 0 => {
1226                log::info!("Context {}: edit buffer", context_index);
1227                context.update(cx, |context, cx| {
1228                    context
1229                        .buffer
1230                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1231                });
1232                mutation_count -= 1;
1233            }
1234            30..=44 if mutation_count > 0 => {
1235                context.update(cx, |context, cx| {
1236                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1237                    log::info!("Context {}: split message at {:?}", context_index, range);
1238                    context.split_message(range, cx);
1239                });
1240                mutation_count -= 1;
1241            }
1242            45..=59 if mutation_count > 0 => {
1243                context.update(cx, |context, cx| {
1244                    if let Some(message) = context.messages(cx).choose(&mut rng) {
1245                        let role = *[Role::User, Role::Assistant, Role::System]
1246                            .choose(&mut rng)
1247                            .unwrap();
1248                        log::info!(
1249                            "Context {}: insert message after {:?} with {:?}",
1250                            context_index,
1251                            message.id,
1252                            role
1253                        );
1254                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1255                    }
1256                });
1257                mutation_count -= 1;
1258            }
1259            60..=74 if mutation_count > 0 => {
1260                context.update(cx, |context, cx| {
1261                    let command_text = "/".to_string()
1262                        + slash_commands
1263                            .command_names()
1264                            .choose(&mut rng)
1265                            .unwrap()
1266                            .clone()
1267                            .as_ref();
1268
1269                    let command_range = context.buffer.update(cx, |buffer, cx| {
1270                        let offset = buffer.random_byte_range(0, &mut rng).start;
1271                        buffer.edit(
1272                            [(offset..offset, format!("\n{}\n", command_text))],
1273                            None,
1274                            cx,
1275                        );
1276                        offset + 1..offset + 1 + command_text.len()
1277                    });
1278
1279                    let output_text = RandomCharIter::new(&mut rng)
1280                        .filter(|c| *c != '\r')
1281                        .take(10)
1282                        .collect::<String>();
1283
1284                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1285                        role: Role::User,
1286                        merge_same_roles: true,
1287                    })];
1288
1289                    let num_sections = rng.gen_range(0..=3);
1290                    let mut section_start = 0;
1291                    for _ in 0..num_sections {
1292                        let mut section_end = rng.gen_range(section_start..=output_text.len());
1293                        while !output_text.is_char_boundary(section_end) {
1294                            section_end += 1;
1295                        }
1296                        events.push(Ok(SlashCommandEvent::StartSection {
1297                            icon: IconName::Ai,
1298                            label: "section".into(),
1299                            metadata: None,
1300                        }));
1301                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1302                            text: output_text[section_start..section_end].to_string(),
1303                            run_commands_in_text: false,
1304                        })));
1305                        events.push(Ok(SlashCommandEvent::EndSection));
1306                        section_start = section_end;
1307                    }
1308
1309                    if section_start < output_text.len() {
1310                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1311                            text: output_text[section_start..].to_string(),
1312                            run_commands_in_text: false,
1313                        })));
1314                    }
1315
1316                    log::info!(
1317                        "Context {}: insert slash command output at {:?} with {:?} events",
1318                        context_index,
1319                        command_range,
1320                        events.len()
1321                    );
1322
1323                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1324                        ..context.buffer.read(cx).anchor_after(command_range.end);
1325                    context.insert_command_output(
1326                        command_range,
1327                        "/command",
1328                        Task::ready(Ok(stream::iter(events).boxed())),
1329                        true,
1330                        cx,
1331                    );
1332                });
1333                cx.run_until_parked();
1334                mutation_count -= 1;
1335            }
1336            75..=84 if mutation_count > 0 => {
1337                context.update(cx, |context, cx| {
1338                    if let Some(message) = context.messages(cx).choose(&mut rng) {
1339                        let new_status = match rng.gen_range(0..3) {
1340                            0 => MessageStatus::Done,
1341                            1 => MessageStatus::Pending,
1342                            _ => MessageStatus::Error(SharedString::from("Random error")),
1343                        };
1344                        log::info!(
1345                            "Context {}: update message {:?} status to {:?}",
1346                            context_index,
1347                            message.id,
1348                            new_status
1349                        );
1350                        context.update_metadata(message.id, cx, |metadata| {
1351                            metadata.status = new_status;
1352                        });
1353                    }
1354                });
1355                mutation_count -= 1;
1356            }
1357            _ => {
1358                let replica_id = context_index as ReplicaId;
1359                if network.lock().is_disconnected(replica_id) {
1360                    network.lock().reconnect_peer(replica_id, 0);
1361
1362                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1363                        let host_context = &contexts[0].read(cx);
1364                        let guest_context = context.read(cx);
1365                        (
1366                            guest_context.serialize_ops(&host_context.version(cx), cx),
1367                            host_context.serialize_ops(&guest_context.version(cx), cx),
1368                        )
1369                    });
1370                    let ops_to_send = ops_to_send.await;
1371                    let ops_to_receive = ops_to_receive
1372                        .await
1373                        .into_iter()
1374                        .map(ContextOperation::from_proto)
1375                        .collect::<Result<Vec<_>>>()
1376                        .unwrap();
1377                    log::info!(
1378                        "Context {}: reconnecting. Sent {} operations, received {} operations",
1379                        context_index,
1380                        ops_to_send.len(),
1381                        ops_to_receive.len()
1382                    );
1383
1384                    network.lock().broadcast(replica_id, ops_to_send);
1385                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1386                } else if rng.gen_bool(0.1) && replica_id != 0 {
1387                    log::info!("Context {}: disconnecting", context_index);
1388                    network.lock().disconnect_peer(replica_id);
1389                } else if network.lock().has_unreceived(replica_id) {
1390                    log::info!("Context {}: applying operations", context_index);
1391                    let ops = network.lock().receive(replica_id);
1392                    let ops = ops
1393                        .into_iter()
1394                        .map(ContextOperation::from_proto)
1395                        .collect::<Result<Vec<_>>>()
1396                        .unwrap();
1397                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
1398                }
1399            }
1400        }
1401    }
1402
1403    cx.read(|cx| {
1404        let first_context = contexts[0].read(cx);
1405        for context in &contexts[1..] {
1406            let context = context.read(cx);
1407            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1408            assert_eq!(
1409                context.buffer.read(cx).text(),
1410                first_context.buffer.read(cx).text(),
1411                "Context {} text != Context 0 text",
1412                context.buffer.read(cx).replica_id()
1413            );
1414            assert_eq!(
1415                context.message_anchors,
1416                first_context.message_anchors,
1417                "Context {} messages != Context 0 messages",
1418                context.buffer.read(cx).replica_id()
1419            );
1420            assert_eq!(
1421                context.messages_metadata,
1422                first_context.messages_metadata,
1423                "Context {} message metadata != Context 0 message metadata",
1424                context.buffer.read(cx).replica_id()
1425            );
1426            assert_eq!(
1427                context.slash_command_output_sections,
1428                first_context.slash_command_output_sections,
1429                "Context {} slash command output sections != Context 0 slash command output sections",
1430                context.buffer.read(cx).replica_id()
1431            );
1432        }
1433    });
1434}
1435
1436#[gpui::test]
1437fn test_mark_cache_anchors(cx: &mut AppContext) {
1438    let settings_store = SettingsStore::test(cx);
1439    LanguageModelRegistry::test(cx);
1440    cx.set_global(settings_store);
1441    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1442    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1443    let context = cx.new_model(|cx| {
1444        Context::local(
1445            registry,
1446            None,
1447            None,
1448            prompt_builder.clone(),
1449            Arc::new(SlashCommandWorkingSet::default()),
1450            Arc::new(ToolWorkingSet::default()),
1451            cx,
1452        )
1453    });
1454    let buffer = context.read(cx).buffer.clone();
1455
1456    // Create a test cache configuration
1457    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1458        max_cache_anchors: 3,
1459        should_speculate: true,
1460        min_total_token: 10,
1461    });
1462
1463    let message_1 = context.read(cx).message_anchors[0].clone();
1464
1465    context.update(cx, |context, cx| {
1466        context.mark_cache_anchors(cache_configuration, false, cx)
1467    });
1468
1469    assert_eq!(
1470        messages_cache(&context, cx)
1471            .iter()
1472            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1473            .count(),
1474        0,
1475        "Empty messages should not have any cache anchors."
1476    );
1477
1478    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1479    let message_2 = context
1480        .update(cx, |context, cx| {
1481            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1482        })
1483        .unwrap();
1484
1485    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1486    let message_3 = context
1487        .update(cx, |context, cx| {
1488            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1489        })
1490        .unwrap();
1491    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1492
1493    context.update(cx, |context, cx| {
1494        context.mark_cache_anchors(cache_configuration, false, cx)
1495    });
1496    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1497    assert_eq!(
1498        messages_cache(&context, cx)
1499            .iter()
1500            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1501            .count(),
1502        0,
1503        "Messages should not be marked for cache before going over the token minimum."
1504    );
1505    context.update(cx, |context, _| {
1506        context.token_count = Some(20);
1507    });
1508
1509    context.update(cx, |context, cx| {
1510        context.mark_cache_anchors(cache_configuration, true, cx)
1511    });
1512    assert_eq!(
1513        messages_cache(&context, cx)
1514            .iter()
1515            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1516            .collect::<Vec<bool>>(),
1517        vec![true, true, false],
1518        "Last message should not be an anchor on speculative request."
1519    );
1520
1521    context
1522        .update(cx, |context, cx| {
1523            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1524        })
1525        .unwrap();
1526
1527    context.update(cx, |context, cx| {
1528        context.mark_cache_anchors(cache_configuration, false, cx)
1529    });
1530    assert_eq!(
1531        messages_cache(&context, cx)
1532            .iter()
1533            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1534            .collect::<Vec<bool>>(),
1535        vec![false, true, true, false],
1536        "Most recent message should also be cached if not a speculative request."
1537    );
1538    context.update(cx, |context, cx| {
1539        context.update_cache_status_for_completion(cx)
1540    });
1541    assert_eq!(
1542        messages_cache(&context, cx)
1543            .iter()
1544            .map(|(_, cache)| cache
1545                .as_ref()
1546                .map_or(None, |cache| Some(cache.status.clone())))
1547            .collect::<Vec<Option<CacheStatus>>>(),
1548        vec![
1549            Some(CacheStatus::Cached),
1550            Some(CacheStatus::Cached),
1551            Some(CacheStatus::Cached),
1552            None
1553        ],
1554        "All user messages prior to anchor should be marked as cached."
1555    );
1556
1557    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1558    context.update(cx, |context, cx| {
1559        context.mark_cache_anchors(cache_configuration, false, cx)
1560    });
1561    assert_eq!(
1562        messages_cache(&context, cx)
1563            .iter()
1564            .map(|(_, cache)| cache
1565                .as_ref()
1566                .map_or(None, |cache| Some(cache.status.clone())))
1567            .collect::<Vec<Option<CacheStatus>>>(),
1568        vec![
1569            Some(CacheStatus::Cached),
1570            Some(CacheStatus::Cached),
1571            Some(CacheStatus::Pending),
1572            None
1573        ],
1574        "Modifying a message should invalidate it's cache but leave previous messages."
1575    );
1576    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1577    context.update(cx, |context, cx| {
1578        context.mark_cache_anchors(cache_configuration, false, cx)
1579    });
1580    assert_eq!(
1581        messages_cache(&context, cx)
1582            .iter()
1583            .map(|(_, cache)| cache
1584                .as_ref()
1585                .map_or(None, |cache| Some(cache.status.clone())))
1586            .collect::<Vec<Option<CacheStatus>>>(),
1587        vec![
1588            Some(CacheStatus::Pending),
1589            Some(CacheStatus::Pending),
1590            Some(CacheStatus::Pending),
1591            None
1592        ],
1593        "Modifying a message should invalidate all future messages."
1594    );
1595}
1596
1597fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1598    context
1599        .read(cx)
1600        .messages(cx)
1601        .map(|message| (message.id, message.role, message.offset_range))
1602        .collect()
1603}
1604
1605fn messages_cache(
1606    context: &Model<Context>,
1607    cx: &AppContext,
1608) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1609    context
1610        .read(cx)
1611        .messages(cx)
1612        .map(|message| (message.id, message.cache.clone()))
1613        .collect()
1614}
1615
1616#[derive(Clone)]
1617struct FakeSlashCommand(String);
1618
1619impl SlashCommand for FakeSlashCommand {
1620    fn name(&self) -> String {
1621        self.0.clone()
1622    }
1623
1624    fn description(&self) -> String {
1625        format!("Fake slash command: {}", self.0)
1626    }
1627
1628    fn menu_text(&self) -> String {
1629        format!("Run fake command: {}", self.0)
1630    }
1631
1632    fn complete_argument(
1633        self: Arc<Self>,
1634        _arguments: &[String],
1635        _cancel: Arc<AtomicBool>,
1636        _workspace: Option<WeakView<Workspace>>,
1637        _cx: &mut WindowContext,
1638    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1639        Task::ready(Ok(vec![]))
1640    }
1641
1642    fn requires_argument(&self) -> bool {
1643        false
1644    }
1645
1646    fn run(
1647        self: Arc<Self>,
1648        _arguments: &[String],
1649        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1650        _context_buffer: BufferSnapshot,
1651        _workspace: WeakView<Workspace>,
1652        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1653        _cx: &mut WindowContext,
1654    ) -> Task<SlashCommandResult> {
1655        Task::ready(Ok(SlashCommandOutput {
1656            text: format!("Executed fake command: {}", self.0),
1657            sections: vec![],
1658            run_commands_in_text: false,
1659        }
1660        .to_event_stream()))
1661    }
1662}