context_tests.rs

   1use crate::{
   2    AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation,
   3    InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  20use parking_lot::Mutex;
  21use pretty_assertions::assert_eq;
  22use project::Project;
  23use prompt_store::PromptBuilder;
  24use rand::prelude::*;
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{
  28    cell::RefCell,
  29    env,
  30    ops::Range,
  31    path::Path,
  32    rc::Rc,
  33    sync::{Arc, atomic::AtomicBool},
  34};
  35use text::{ReplicaId, ToOffset, network::Network};
  36use ui::{IconName, Window};
  37use unindent::Unindent;
  38use util::RandomCharIter;
  39use workspace::Workspace;
  40
  41#[gpui::test]
  42fn test_inserting_and_removing_messages(cx: &mut App) {
  43    init_test(cx);
  44
  45    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  46    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  47    let context = cx.new(|cx| {
  48        AssistantContext::local(
  49            registry,
  50            None,
  51            None,
  52            prompt_builder.clone(),
  53            Arc::new(SlashCommandWorkingSet::default()),
  54            cx,
  55        )
  56    });
  57    let buffer = context.read(cx).buffer.clone();
  58
  59    let message_1 = context.read(cx).message_anchors[0].clone();
  60    assert_eq!(
  61        messages(&context, cx),
  62        vec![(message_1.id, Role::User, 0..0)]
  63    );
  64
  65    let message_2 = context.update(cx, |context, cx| {
  66        context
  67            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  68            .unwrap()
  69    });
  70    assert_eq!(
  71        messages(&context, cx),
  72        vec![
  73            (message_1.id, Role::User, 0..1),
  74            (message_2.id, Role::Assistant, 1..1)
  75        ]
  76    );
  77
  78    buffer.update(cx, |buffer, cx| {
  79        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  80    });
  81    assert_eq!(
  82        messages(&context, cx),
  83        vec![
  84            (message_1.id, Role::User, 0..2),
  85            (message_2.id, Role::Assistant, 2..3)
  86        ]
  87    );
  88
  89    let message_3 = context.update(cx, |context, cx| {
  90        context
  91            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  92            .unwrap()
  93    });
  94    assert_eq!(
  95        messages(&context, cx),
  96        vec![
  97            (message_1.id, Role::User, 0..2),
  98            (message_2.id, Role::Assistant, 2..4),
  99            (message_3.id, Role::User, 4..4)
 100        ]
 101    );
 102
 103    let message_4 = context.update(cx, |context, cx| {
 104        context
 105            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 106            .unwrap()
 107    });
 108    assert_eq!(
 109        messages(&context, cx),
 110        vec![
 111            (message_1.id, Role::User, 0..2),
 112            (message_2.id, Role::Assistant, 2..4),
 113            (message_4.id, Role::User, 4..5),
 114            (message_3.id, Role::User, 5..5),
 115        ]
 116    );
 117
 118    buffer.update(cx, |buffer, cx| {
 119        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 120    });
 121    assert_eq!(
 122        messages(&context, cx),
 123        vec![
 124            (message_1.id, Role::User, 0..2),
 125            (message_2.id, Role::Assistant, 2..4),
 126            (message_4.id, Role::User, 4..6),
 127            (message_3.id, Role::User, 6..7),
 128        ]
 129    );
 130
 131    // Deleting across message boundaries merges the messages.
 132    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 133    assert_eq!(
 134        messages(&context, cx),
 135        vec![
 136            (message_1.id, Role::User, 0..3),
 137            (message_3.id, Role::User, 3..4),
 138        ]
 139    );
 140
 141    // Undoing the deletion should also undo the merge.
 142    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 143    assert_eq!(
 144        messages(&context, cx),
 145        vec![
 146            (message_1.id, Role::User, 0..2),
 147            (message_2.id, Role::Assistant, 2..4),
 148            (message_4.id, Role::User, 4..6),
 149            (message_3.id, Role::User, 6..7),
 150        ]
 151    );
 152
 153    // Redoing the deletion should also redo the merge.
 154    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 155    assert_eq!(
 156        messages(&context, cx),
 157        vec![
 158            (message_1.id, Role::User, 0..3),
 159            (message_3.id, Role::User, 3..4),
 160        ]
 161    );
 162
 163    // Ensure we can still insert after a merged message.
 164    let message_5 = context.update(cx, |context, cx| {
 165        context
 166            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 167            .unwrap()
 168    });
 169    assert_eq!(
 170        messages(&context, cx),
 171        vec![
 172            (message_1.id, Role::User, 0..3),
 173            (message_5.id, Role::System, 3..4),
 174            (message_3.id, Role::User, 4..5)
 175        ]
 176    );
 177}
 178
 179#[gpui::test]
 180fn test_message_splitting(cx: &mut App) {
 181    init_test(cx);
 182
 183    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 184
 185    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 186    let context = cx.new(|cx| {
 187        AssistantContext::local(
 188            registry.clone(),
 189            None,
 190            None,
 191            prompt_builder.clone(),
 192            Arc::new(SlashCommandWorkingSet::default()),
 193            cx,
 194        )
 195    });
 196    let buffer = context.read(cx).buffer.clone();
 197
 198    let message_1 = context.read(cx).message_anchors[0].clone();
 199    assert_eq!(
 200        messages(&context, cx),
 201        vec![(message_1.id, Role::User, 0..0)]
 202    );
 203
 204    buffer.update(cx, |buffer, cx| {
 205        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 206    });
 207
 208    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 209    let message_2 = message_2.unwrap();
 210
 211    // We recycle newlines in the middle of a split message
 212    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 213    assert_eq!(
 214        messages(&context, cx),
 215        vec![
 216            (message_1.id, Role::User, 0..4),
 217            (message_2.id, Role::User, 4..16),
 218        ]
 219    );
 220
 221    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 222    let message_3 = message_3.unwrap();
 223
 224    // We don't recycle newlines at the end of a split message
 225    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 226    assert_eq!(
 227        messages(&context, cx),
 228        vec![
 229            (message_1.id, Role::User, 0..4),
 230            (message_3.id, Role::User, 4..5),
 231            (message_2.id, Role::User, 5..17),
 232        ]
 233    );
 234
 235    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 236    let message_4 = message_4.unwrap();
 237    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 238    assert_eq!(
 239        messages(&context, cx),
 240        vec![
 241            (message_1.id, Role::User, 0..4),
 242            (message_3.id, Role::User, 4..5),
 243            (message_2.id, Role::User, 5..9),
 244            (message_4.id, Role::User, 9..17),
 245        ]
 246    );
 247
 248    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 249    let message_5 = message_5.unwrap();
 250    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 251    assert_eq!(
 252        messages(&context, cx),
 253        vec![
 254            (message_1.id, Role::User, 0..4),
 255            (message_3.id, Role::User, 4..5),
 256            (message_2.id, Role::User, 5..9),
 257            (message_4.id, Role::User, 9..10),
 258            (message_5.id, Role::User, 10..18),
 259        ]
 260    );
 261
 262    let (message_6, message_7) =
 263        context.update(cx, |context, cx| context.split_message(14..16, cx));
 264    let message_6 = message_6.unwrap();
 265    let message_7 = message_7.unwrap();
 266    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 267    assert_eq!(
 268        messages(&context, cx),
 269        vec![
 270            (message_1.id, Role::User, 0..4),
 271            (message_3.id, Role::User, 4..5),
 272            (message_2.id, Role::User, 5..9),
 273            (message_4.id, Role::User, 9..10),
 274            (message_5.id, Role::User, 10..14),
 275            (message_6.id, Role::User, 14..17),
 276            (message_7.id, Role::User, 17..19),
 277        ]
 278    );
 279}
 280
 281#[gpui::test]
 282fn test_messages_for_offsets(cx: &mut App) {
 283    init_test(cx);
 284
 285    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 286    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 287    let context = cx.new(|cx| {
 288        AssistantContext::local(
 289            registry,
 290            None,
 291            None,
 292            prompt_builder.clone(),
 293            Arc::new(SlashCommandWorkingSet::default()),
 294            cx,
 295        )
 296    });
 297    let buffer = context.read(cx).buffer.clone();
 298
 299    let message_1 = context.read(cx).message_anchors[0].clone();
 300    assert_eq!(
 301        messages(&context, cx),
 302        vec![(message_1.id, Role::User, 0..0)]
 303    );
 304
 305    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 306    let message_2 = context
 307        .update(cx, |context, cx| {
 308            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 309        })
 310        .unwrap();
 311    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 312
 313    let message_3 = context
 314        .update(cx, |context, cx| {
 315            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 316        })
 317        .unwrap();
 318    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 319
 320    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 321    assert_eq!(
 322        messages(&context, cx),
 323        vec![
 324            (message_1.id, Role::User, 0..4),
 325            (message_2.id, Role::User, 4..8),
 326            (message_3.id, Role::User, 8..11)
 327        ]
 328    );
 329
 330    assert_eq!(
 331        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 332        [message_1.id, message_2.id, message_3.id]
 333    );
 334    assert_eq!(
 335        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 336        [message_1.id, message_3.id]
 337    );
 338
 339    let message_4 = context
 340        .update(cx, |context, cx| {
 341            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 342        })
 343        .unwrap();
 344    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 345    assert_eq!(
 346        messages(&context, cx),
 347        vec![
 348            (message_1.id, Role::User, 0..4),
 349            (message_2.id, Role::User, 4..8),
 350            (message_3.id, Role::User, 8..12),
 351            (message_4.id, Role::User, 12..12)
 352        ]
 353    );
 354    assert_eq!(
 355        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 356        [message_1.id, message_2.id, message_3.id, message_4.id]
 357    );
 358
 359    fn message_ids_for_offsets(
 360        context: &Entity<AssistantContext>,
 361        offsets: &[usize],
 362        cx: &App,
 363    ) -> Vec<MessageId> {
 364        context
 365            .read(cx)
 366            .messages_for_offsets(offsets.iter().copied(), cx)
 367            .into_iter()
 368            .map(|message| message.id)
 369            .collect()
 370    }
 371}
 372
 373#[gpui::test]
 374async fn test_slash_commands(cx: &mut TestAppContext) {
 375    cx.update(init_test);
 376
 377    let fs = FakeFs::new(cx.background_executor.clone());
 378
 379    fs.insert_tree(
 380        "/test",
 381        json!({
 382            "src": {
 383                "lib.rs": "fn one() -> usize { 1 }",
 384                "main.rs": "
 385                    use crate::one;
 386                    fn main() { one(); }
 387                ".unindent(),
 388            }
 389        }),
 390    )
 391    .await;
 392
 393    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 394    slash_command_registry.register_command(FileSlashCommand, false);
 395
 396    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 397    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 398    let context = cx.new(|cx| {
 399        AssistantContext::local(
 400            registry.clone(),
 401            None,
 402            None,
 403            prompt_builder.clone(),
 404            Arc::new(SlashCommandWorkingSet::default()),
 405            cx,
 406        )
 407    });
 408
 409    #[derive(Default)]
 410    struct ContextRanges {
 411        parsed_commands: HashSet<Range<language::Anchor>>,
 412        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 413        output_sections: HashSet<Range<language::Anchor>>,
 414    }
 415
 416    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 417    context.update(cx, |_, cx| {
 418        cx.subscribe(&context, {
 419            let context_ranges = context_ranges.clone();
 420            move |context, _, event, _| {
 421                let mut context_ranges = context_ranges.borrow_mut();
 422                match event {
 423                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 424                        let command = context.invoked_slash_command(command_id).unwrap();
 425                        context_ranges
 426                            .command_outputs
 427                            .insert(*command_id, command.range.clone());
 428                    }
 429                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 430                        for range in removed {
 431                            context_ranges.parsed_commands.remove(range);
 432                        }
 433                        for command in updated {
 434                            context_ranges
 435                                .parsed_commands
 436                                .insert(command.source_range.clone());
 437                        }
 438                    }
 439                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 440                        context_ranges.output_sections.insert(section.range.clone());
 441                    }
 442                    _ => {}
 443                }
 444            }
 445        })
 446        .detach();
 447    });
 448
 449    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 450
 451    // Insert a slash command
 452    buffer.update(cx, |buffer, cx| {
 453        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 454    });
 455    assert_text_and_context_ranges(
 456        &buffer,
 457        &context_ranges,
 458        &"
 459        «/file src/lib.rs»"
 460            .unindent(),
 461        cx,
 462    );
 463
 464    // Edit the argument of the slash command.
 465    buffer.update(cx, |buffer, cx| {
 466        let edit_offset = buffer.text().find("lib.rs").unwrap();
 467        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 468    });
 469    assert_text_and_context_ranges(
 470        &buffer,
 471        &context_ranges,
 472        &"
 473        «/file src/main.rs»"
 474            .unindent(),
 475        cx,
 476    );
 477
 478    // Edit the name of the slash command, using one that doesn't exist.
 479    buffer.update(cx, |buffer, cx| {
 480        let edit_offset = buffer.text().find("/file").unwrap();
 481        buffer.edit(
 482            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 483            None,
 484            cx,
 485        );
 486    });
 487    assert_text_and_context_ranges(
 488        &buffer,
 489        &context_ranges,
 490        &"
 491        /unknown src/main.rs"
 492            .unindent(),
 493        cx,
 494    );
 495
 496    // Undoing the insertion of an non-existent slash command resorts the previous one.
 497    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 498    assert_text_and_context_ranges(
 499        &buffer,
 500        &context_ranges,
 501        &"
 502        «/file src/main.rs»"
 503            .unindent(),
 504        cx,
 505    );
 506
 507    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 508    context.update(cx, |context, cx| {
 509        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 510        context.insert_command_output(
 511            command_source_range,
 512            "file",
 513            Task::ready(Ok(command_output_rx.boxed())),
 514            true,
 515            cx,
 516        );
 517    });
 518    assert_text_and_context_ranges(
 519        &buffer,
 520        &context_ranges,
 521        &"
 522        ⟦«/file src/main.rs»
 523        …⟧
 524        "
 525        .unindent(),
 526        cx,
 527    );
 528
 529    command_output_tx
 530        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 531            icon: IconName::Ai,
 532            label: "src/main.rs".into(),
 533            metadata: None,
 534        }))
 535        .unwrap();
 536    command_output_tx
 537        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 538        .unwrap();
 539    cx.run_until_parked();
 540    assert_text_and_context_ranges(
 541        &buffer,
 542        &context_ranges,
 543        &"
 544        ⟦«/file src/main.rs»
 545        src/main.rs…⟧
 546        "
 547        .unindent(),
 548        cx,
 549    );
 550
 551    command_output_tx
 552        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 553        .unwrap();
 554    cx.run_until_parked();
 555    assert_text_and_context_ranges(
 556        &buffer,
 557        &context_ranges,
 558        &"
 559        ⟦«/file src/main.rs»
 560        src/main.rs
 561        fn main() {}…⟧
 562        "
 563        .unindent(),
 564        cx,
 565    );
 566
 567    command_output_tx
 568        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 569        .unwrap();
 570    cx.run_until_parked();
 571    assert_text_and_context_ranges(
 572        &buffer,
 573        &context_ranges,
 574        &"
 575        ⟦«/file src/main.rs»
 576        ⟪src/main.rs
 577        fn main() {}⟫…⟧
 578        "
 579        .unindent(),
 580        cx,
 581    );
 582
 583    drop(command_output_tx);
 584    cx.run_until_parked();
 585    assert_text_and_context_ranges(
 586        &buffer,
 587        &context_ranges,
 588        &"
 589        ⟦⟪src/main.rs
 590        fn main() {}⟫⟧
 591        "
 592        .unindent(),
 593        cx,
 594    );
 595
 596    #[track_caller]
 597    fn assert_text_and_context_ranges(
 598        buffer: &Entity<Buffer>,
 599        ranges: &RefCell<ContextRanges>,
 600        expected_marked_text: &str,
 601        cx: &mut TestAppContext,
 602    ) {
 603        let mut actual_marked_text = String::new();
 604        buffer.update(cx, |buffer, _| {
 605            struct Endpoint {
 606                offset: usize,
 607                marker: char,
 608            }
 609
 610            let ranges = ranges.borrow();
 611            let mut endpoints = Vec::new();
 612            for range in ranges.command_outputs.values() {
 613                endpoints.push(Endpoint {
 614                    offset: range.start.to_offset(buffer),
 615                    marker: '',
 616                });
 617            }
 618            for range in ranges.parsed_commands.iter() {
 619                endpoints.push(Endpoint {
 620                    offset: range.start.to_offset(buffer),
 621                    marker: '«',
 622                });
 623            }
 624            for range in ranges.output_sections.iter() {
 625                endpoints.push(Endpoint {
 626                    offset: range.start.to_offset(buffer),
 627                    marker: '',
 628                });
 629            }
 630
 631            for range in ranges.output_sections.iter() {
 632                endpoints.push(Endpoint {
 633                    offset: range.end.to_offset(buffer),
 634                    marker: '',
 635                });
 636            }
 637            for range in ranges.parsed_commands.iter() {
 638                endpoints.push(Endpoint {
 639                    offset: range.end.to_offset(buffer),
 640                    marker: '»',
 641                });
 642            }
 643            for range in ranges.command_outputs.values() {
 644                endpoints.push(Endpoint {
 645                    offset: range.end.to_offset(buffer),
 646                    marker: '',
 647                });
 648            }
 649
 650            endpoints.sort_by_key(|endpoint| endpoint.offset);
 651            let mut offset = 0;
 652            for endpoint in endpoints {
 653                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 654                actual_marked_text.push(endpoint.marker);
 655                offset = endpoint.offset;
 656            }
 657            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 658        });
 659
 660        assert_eq!(actual_marked_text, expected_marked_text);
 661    }
 662}
 663
 664#[gpui::test]
 665async fn test_serialization(cx: &mut TestAppContext) {
 666    cx.update(init_test);
 667
 668    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 669    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 670    let context = cx.new(|cx| {
 671        AssistantContext::local(
 672            registry.clone(),
 673            None,
 674            None,
 675            prompt_builder.clone(),
 676            Arc::new(SlashCommandWorkingSet::default()),
 677            cx,
 678        )
 679    });
 680    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 681    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
 682    let message_1 = context.update(cx, |context, cx| {
 683        context
 684            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 685            .unwrap()
 686    });
 687    let message_2 = context.update(cx, |context, cx| {
 688        context
 689            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 690            .unwrap()
 691    });
 692    buffer.update(cx, |buffer, cx| {
 693        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 694        buffer.finalize_last_transaction();
 695    });
 696    let _message_3 = context.update(cx, |context, cx| {
 697        context
 698            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 699            .unwrap()
 700    });
 701    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 702    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 703    assert_eq!(
 704        cx.read(|cx| messages(&context, cx)),
 705        [
 706            (message_0, Role::User, 0..2),
 707            (message_1.id, Role::Assistant, 2..6),
 708            (message_2.id, Role::System, 6..6),
 709        ]
 710    );
 711
 712    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 713    let deserialized_context = cx.new(|cx| {
 714        AssistantContext::deserialize(
 715            serialized_context,
 716            Path::new("").into(),
 717            registry.clone(),
 718            prompt_builder.clone(),
 719            Arc::new(SlashCommandWorkingSet::default()),
 720            None,
 721            None,
 722            cx,
 723        )
 724    });
 725    let deserialized_buffer =
 726        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
 727    assert_eq!(
 728        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 729        "a\nb\nc\n"
 730    );
 731    assert_eq!(
 732        cx.read(|cx| messages(&deserialized_context, cx)),
 733        [
 734            (message_0, Role::User, 0..2),
 735            (message_1.id, Role::Assistant, 2..6),
 736            (message_2.id, Role::System, 6..6),
 737        ]
 738    );
 739}
 740
 741#[gpui::test(iterations = 100)]
 742async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
 743    cx.update(init_test);
 744
 745    let min_peers = env::var("MIN_PEERS")
 746        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
 747        .unwrap_or(2);
 748    let max_peers = env::var("MAX_PEERS")
 749        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 750        .unwrap_or(5);
 751    let operations = env::var("OPERATIONS")
 752        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 753        .unwrap_or(50);
 754
 755    let slash_commands = cx.update(SlashCommandRegistry::default_global);
 756    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
 757    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
 758    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
 759
 760    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
 761    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
 762    let mut contexts = Vec::new();
 763
 764    let num_peers = rng.gen_range(min_peers..=max_peers);
 765    let context_id = ContextId::new();
 766    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 767    for i in 0..num_peers {
 768        let context = cx.new(|cx| {
 769            AssistantContext::new(
 770                context_id.clone(),
 771                i as ReplicaId,
 772                language::Capability::ReadWrite,
 773                registry.clone(),
 774                prompt_builder.clone(),
 775                Arc::new(SlashCommandWorkingSet::default()),
 776                None,
 777                None,
 778                cx,
 779            )
 780        });
 781
 782        cx.update(|cx| {
 783            cx.subscribe(&context, {
 784                let network = network.clone();
 785                move |_, event, _| {
 786                    if let ContextEvent::Operation(op) = event {
 787                        network
 788                            .lock()
 789                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
 790                    }
 791                }
 792            })
 793            .detach();
 794        });
 795
 796        contexts.push(context);
 797        network.lock().add_peer(i as ReplicaId);
 798    }
 799
 800    let mut mutation_count = operations;
 801
 802    while mutation_count > 0
 803        || !network.lock().is_idle()
 804        || network.lock().contains_disconnected_peers()
 805    {
 806        let context_index = rng.gen_range(0..contexts.len());
 807        let context = &contexts[context_index];
 808
 809        match rng.gen_range(0..100) {
 810            0..=29 if mutation_count > 0 => {
 811                log::info!("Context {}: edit buffer", context_index);
 812                context.update(cx, |context, cx| {
 813                    context
 814                        .buffer
 815                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
 816                });
 817                mutation_count -= 1;
 818            }
 819            30..=44 if mutation_count > 0 => {
 820                context.update(cx, |context, cx| {
 821                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
 822                    log::info!("Context {}: split message at {:?}", context_index, range);
 823                    context.split_message(range, cx);
 824                });
 825                mutation_count -= 1;
 826            }
 827            45..=59 if mutation_count > 0 => {
 828                context.update(cx, |context, cx| {
 829                    if let Some(message) = context.messages(cx).choose(&mut rng) {
 830                        let role = *[Role::User, Role::Assistant, Role::System]
 831                            .choose(&mut rng)
 832                            .unwrap();
 833                        log::info!(
 834                            "Context {}: insert message after {:?} with {:?}",
 835                            context_index,
 836                            message.id,
 837                            role
 838                        );
 839                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
 840                    }
 841                });
 842                mutation_count -= 1;
 843            }
 844            60..=74 if mutation_count > 0 => {
 845                context.update(cx, |context, cx| {
 846                    let command_text = "/".to_string()
 847                        + slash_commands
 848                            .command_names()
 849                            .choose(&mut rng)
 850                            .unwrap()
 851                            .clone()
 852                            .as_ref();
 853
 854                    let command_range = context.buffer.update(cx, |buffer, cx| {
 855                        let offset = buffer.random_byte_range(0, &mut rng).start;
 856                        buffer.edit(
 857                            [(offset..offset, format!("\n{}\n", command_text))],
 858                            None,
 859                            cx,
 860                        );
 861                        offset + 1..offset + 1 + command_text.len()
 862                    });
 863
 864                    let output_text = RandomCharIter::new(&mut rng)
 865                        .filter(|c| *c != '\r')
 866                        .take(10)
 867                        .collect::<String>();
 868
 869                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
 870                        role: Role::User,
 871                        merge_same_roles: true,
 872                    })];
 873
 874                    let num_sections = rng.gen_range(0..=3);
 875                    let mut section_start = 0;
 876                    for _ in 0..num_sections {
 877                        let mut section_end = rng.gen_range(section_start..=output_text.len());
 878                        while !output_text.is_char_boundary(section_end) {
 879                            section_end += 1;
 880                        }
 881                        events.push(Ok(SlashCommandEvent::StartSection {
 882                            icon: IconName::Ai,
 883                            label: "section".into(),
 884                            metadata: None,
 885                        }));
 886                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 887                            text: output_text[section_start..section_end].to_string(),
 888                            run_commands_in_text: false,
 889                        })));
 890                        events.push(Ok(SlashCommandEvent::EndSection));
 891                        section_start = section_end;
 892                    }
 893
 894                    if section_start < output_text.len() {
 895                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 896                            text: output_text[section_start..].to_string(),
 897                            run_commands_in_text: false,
 898                        })));
 899                    }
 900
 901                    log::info!(
 902                        "Context {}: insert slash command output at {:?} with {:?} events",
 903                        context_index,
 904                        command_range,
 905                        events.len()
 906                    );
 907
 908                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
 909                        ..context.buffer.read(cx).anchor_after(command_range.end);
 910                    context.insert_command_output(
 911                        command_range,
 912                        "/command",
 913                        Task::ready(Ok(stream::iter(events).boxed())),
 914                        true,
 915                        cx,
 916                    );
 917                });
 918                cx.run_until_parked();
 919                mutation_count -= 1;
 920            }
 921            75..=84 if mutation_count > 0 => {
 922                context.update(cx, |context, cx| {
 923                    if let Some(message) = context.messages(cx).choose(&mut rng) {
 924                        let new_status = match rng.gen_range(0..3) {
 925                            0 => MessageStatus::Done,
 926                            1 => MessageStatus::Pending,
 927                            _ => MessageStatus::Error(SharedString::from("Random error")),
 928                        };
 929                        log::info!(
 930                            "Context {}: update message {:?} status to {:?}",
 931                            context_index,
 932                            message.id,
 933                            new_status
 934                        );
 935                        context.update_metadata(message.id, cx, |metadata| {
 936                            metadata.status = new_status;
 937                        });
 938                    }
 939                });
 940                mutation_count -= 1;
 941            }
 942            _ => {
 943                let replica_id = context_index as ReplicaId;
 944                if network.lock().is_disconnected(replica_id) {
 945                    network.lock().reconnect_peer(replica_id, 0);
 946
 947                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
 948                        let host_context = &contexts[0].read(cx);
 949                        let guest_context = context.read(cx);
 950                        (
 951                            guest_context.serialize_ops(&host_context.version(cx), cx),
 952                            host_context.serialize_ops(&guest_context.version(cx), cx),
 953                        )
 954                    });
 955                    let ops_to_send = ops_to_send.await;
 956                    let ops_to_receive = ops_to_receive
 957                        .await
 958                        .into_iter()
 959                        .map(ContextOperation::from_proto)
 960                        .collect::<Result<Vec<_>>>()
 961                        .unwrap();
 962                    log::info!(
 963                        "Context {}: reconnecting. Sent {} operations, received {} operations",
 964                        context_index,
 965                        ops_to_send.len(),
 966                        ops_to_receive.len()
 967                    );
 968
 969                    network.lock().broadcast(replica_id, ops_to_send);
 970                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
 971                } else if rng.gen_bool(0.1) && replica_id != 0 {
 972                    log::info!("Context {}: disconnecting", context_index);
 973                    network.lock().disconnect_peer(replica_id);
 974                } else if network.lock().has_unreceived(replica_id) {
 975                    log::info!("Context {}: applying operations", context_index);
 976                    let ops = network.lock().receive(replica_id);
 977                    let ops = ops
 978                        .into_iter()
 979                        .map(ContextOperation::from_proto)
 980                        .collect::<Result<Vec<_>>>()
 981                        .unwrap();
 982                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
 983                }
 984            }
 985        }
 986    }
 987
 988    cx.read(|cx| {
 989        let first_context = contexts[0].read(cx);
 990        for context in &contexts[1..] {
 991            let context = context.read(cx);
 992            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
 993            assert_eq!(
 994                context.buffer.read(cx).text(),
 995                first_context.buffer.read(cx).text(),
 996                "Context {} text != Context 0 text",
 997                context.buffer.read(cx).replica_id()
 998            );
 999            assert_eq!(
1000                context.message_anchors,
1001                first_context.message_anchors,
1002                "Context {} messages != Context 0 messages",
1003                context.buffer.read(cx).replica_id()
1004            );
1005            assert_eq!(
1006                context.messages_metadata,
1007                first_context.messages_metadata,
1008                "Context {} message metadata != Context 0 message metadata",
1009                context.buffer.read(cx).replica_id()
1010            );
1011            assert_eq!(
1012                context.slash_command_output_sections,
1013                first_context.slash_command_output_sections,
1014                "Context {} slash command output sections != Context 0 slash command output sections",
1015                context.buffer.read(cx).replica_id()
1016            );
1017        }
1018    });
1019}
1020
1021#[gpui::test]
1022fn test_mark_cache_anchors(cx: &mut App) {
1023    init_test(cx);
1024
1025    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1026    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1027    let context = cx.new(|cx| {
1028        AssistantContext::local(
1029            registry,
1030            None,
1031            None,
1032            prompt_builder.clone(),
1033            Arc::new(SlashCommandWorkingSet::default()),
1034            cx,
1035        )
1036    });
1037    let buffer = context.read(cx).buffer.clone();
1038
1039    // Create a test cache configuration
1040    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1041        max_cache_anchors: 3,
1042        should_speculate: true,
1043        min_total_token: 10,
1044    });
1045
1046    let message_1 = context.read(cx).message_anchors[0].clone();
1047
1048    context.update(cx, |context, cx| {
1049        context.mark_cache_anchors(cache_configuration, false, cx)
1050    });
1051
1052    assert_eq!(
1053        messages_cache(&context, cx)
1054            .iter()
1055            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1056            .count(),
1057        0,
1058        "Empty messages should not have any cache anchors."
1059    );
1060
1061    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1062    let message_2 = context
1063        .update(cx, |context, cx| {
1064            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1065        })
1066        .unwrap();
1067
1068    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1069    let message_3 = context
1070        .update(cx, |context, cx| {
1071            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1072        })
1073        .unwrap();
1074    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1075
1076    context.update(cx, |context, cx| {
1077        context.mark_cache_anchors(cache_configuration, false, cx)
1078    });
1079    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1080    assert_eq!(
1081        messages_cache(&context, cx)
1082            .iter()
1083            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1084            .count(),
1085        0,
1086        "Messages should not be marked for cache before going over the token minimum."
1087    );
1088    context.update(cx, |context, _| {
1089        context.token_count = Some(20);
1090    });
1091
1092    context.update(cx, |context, cx| {
1093        context.mark_cache_anchors(cache_configuration, true, cx)
1094    });
1095    assert_eq!(
1096        messages_cache(&context, cx)
1097            .iter()
1098            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1099            .collect::<Vec<bool>>(),
1100        vec![true, true, false],
1101        "Last message should not be an anchor on speculative request."
1102    );
1103
1104    context
1105        .update(cx, |context, cx| {
1106            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1107        })
1108        .unwrap();
1109
1110    context.update(cx, |context, cx| {
1111        context.mark_cache_anchors(cache_configuration, false, cx)
1112    });
1113    assert_eq!(
1114        messages_cache(&context, cx)
1115            .iter()
1116            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1117            .collect::<Vec<bool>>(),
1118        vec![false, true, true, false],
1119        "Most recent message should also be cached if not a speculative request."
1120    );
1121    context.update(cx, |context, cx| {
1122        context.update_cache_status_for_completion(cx)
1123    });
1124    assert_eq!(
1125        messages_cache(&context, cx)
1126            .iter()
1127            .map(|(_, cache)| cache
1128                .as_ref()
1129                .map_or(None, |cache| Some(cache.status.clone())))
1130            .collect::<Vec<Option<CacheStatus>>>(),
1131        vec![
1132            Some(CacheStatus::Cached),
1133            Some(CacheStatus::Cached),
1134            Some(CacheStatus::Cached),
1135            None
1136        ],
1137        "All user messages prior to anchor should be marked as cached."
1138    );
1139
1140    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1141    context.update(cx, |context, cx| {
1142        context.mark_cache_anchors(cache_configuration, false, cx)
1143    });
1144    assert_eq!(
1145        messages_cache(&context, cx)
1146            .iter()
1147            .map(|(_, cache)| cache
1148                .as_ref()
1149                .map_or(None, |cache| Some(cache.status.clone())))
1150            .collect::<Vec<Option<CacheStatus>>>(),
1151        vec![
1152            Some(CacheStatus::Cached),
1153            Some(CacheStatus::Cached),
1154            Some(CacheStatus::Pending),
1155            None
1156        ],
1157        "Modifying a message should invalidate it's cache but leave previous messages."
1158    );
1159    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1160    context.update(cx, |context, cx| {
1161        context.mark_cache_anchors(cache_configuration, false, cx)
1162    });
1163    assert_eq!(
1164        messages_cache(&context, cx)
1165            .iter()
1166            .map(|(_, cache)| cache
1167                .as_ref()
1168                .map_or(None, |cache| Some(cache.status.clone())))
1169            .collect::<Vec<Option<CacheStatus>>>(),
1170        vec![
1171            Some(CacheStatus::Pending),
1172            Some(CacheStatus::Pending),
1173            Some(CacheStatus::Pending),
1174            None
1175        ],
1176        "Modifying a message should invalidate all future messages."
1177    );
1178}
1179
1180fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1181    context
1182        .read(cx)
1183        .messages(cx)
1184        .map(|message| (message.id, message.role, message.offset_range))
1185        .collect()
1186}
1187
1188fn messages_cache(
1189    context: &Entity<AssistantContext>,
1190    cx: &App,
1191) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1192    context
1193        .read(cx)
1194        .messages(cx)
1195        .map(|message| (message.id, message.cache.clone()))
1196        .collect()
1197}
1198
/// Shared setup invoked at the top of the tests in this file.
///
/// Creates a test `SettingsStore`, initializes `prompt_store`, installs a test
/// `LanguageModelRegistry`, registers the settings store as a global, and then
/// initializes `language`, `assistant_settings`, and `Project` settings.
///
/// NOTE(review): the calls are kept in their existing order; whether any of them may
/// be reordered is not evident from this file.
fn init_test(cx: &mut App) {
    // Built first, but only registered as a global after the two calls below.
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    language::init(cx);
    assistant_settings::init(cx);
    Project::init_settings(cx);
}
1208
/// Test double for `SlashCommand`. The wrapped `String` is the value returned by
/// `name()` and interpolated into the description, menu text, and run output.
#[derive(Clone)]
struct FakeSlashCommand(String);
1211
1212impl SlashCommand for FakeSlashCommand {
1213    fn name(&self) -> String {
1214        self.0.clone()
1215    }
1216
1217    fn description(&self) -> String {
1218        format!("Fake slash command: {}", self.0)
1219    }
1220
1221    fn menu_text(&self) -> String {
1222        format!("Run fake command: {}", self.0)
1223    }
1224
1225    fn complete_argument(
1226        self: Arc<Self>,
1227        _arguments: &[String],
1228        _cancel: Arc<AtomicBool>,
1229        _workspace: Option<WeakEntity<Workspace>>,
1230        _window: &mut Window,
1231        _cx: &mut App,
1232    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1233        Task::ready(Ok(vec![]))
1234    }
1235
1236    fn requires_argument(&self) -> bool {
1237        false
1238    }
1239
1240    fn run(
1241        self: Arc<Self>,
1242        _arguments: &[String],
1243        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1244        _context_buffer: BufferSnapshot,
1245        _workspace: WeakEntity<Workspace>,
1246        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1247        _window: &mut Window,
1248        _cx: &mut App,
1249    ) -> Task<SlashCommandResult> {
1250        Task::ready(Ok(SlashCommandOutput {
1251            text: format!("Executed fake command: {}", self.0),
1252            sections: vec![],
1253            run_commands_in_text: false,
1254        }
1255        .to_event_stream()))
1256    }
1257}