context_tests.rs

   1use super::{MessageCacheMetadata, WorkflowStepEdit};
   2use crate::{
   3    assistant_panel, prompt_library, slash_command::file_command, CacheStatus, Context,
   4    ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
   5    WorkflowStepEditKind,
   6};
   7use anyhow::Result;
   8use assistant_slash_command::{
   9    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
  10    SlashCommandRegistry,
  11};
  12use collections::HashSet;
  13use fs::FakeFs;
  14use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  15use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
  16use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  17use parking_lot::Mutex;
  18use project::Project;
  19use rand::prelude::*;
  20use serde_json::json;
  21use settings::SettingsStore;
  22use std::{
  23    cell::RefCell,
  24    env,
  25    ops::Range,
  26    path::Path,
  27    rc::Rc,
  28    sync::{atomic::AtomicBool, Arc},
  29};
  30use text::{network::Network, OffsetRangeExt as _, ReplicaId};
  31use ui::{Context as _, WindowContext};
  32use unindent::Unindent;
  33use util::{
  34    test::{generate_marked_text, marked_text_ranges},
  35    RandomCharIter,
  36};
  37use workspace::Workspace;
  38
  39#[gpui::test]
  40fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  41    let settings_store = SettingsStore::test(cx);
  42    LanguageModelRegistry::test(cx);
  43    cx.set_global(settings_store);
  44    assistant_panel::init(cx);
  45    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  46    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  47    let context =
  48        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
  49    let buffer = context.read(cx).buffer.clone();
  50
  51    let message_1 = context.read(cx).message_anchors[0].clone();
  52    assert_eq!(
  53        messages(&context, cx),
  54        vec![(message_1.id, Role::User, 0..0)]
  55    );
  56
  57    let message_2 = context.update(cx, |context, cx| {
  58        context
  59            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  60            .unwrap()
  61    });
  62    assert_eq!(
  63        messages(&context, cx),
  64        vec![
  65            (message_1.id, Role::User, 0..1),
  66            (message_2.id, Role::Assistant, 1..1)
  67        ]
  68    );
  69
  70    buffer.update(cx, |buffer, cx| {
  71        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  72    });
  73    assert_eq!(
  74        messages(&context, cx),
  75        vec![
  76            (message_1.id, Role::User, 0..2),
  77            (message_2.id, Role::Assistant, 2..3)
  78        ]
  79    );
  80
  81    let message_3 = context.update(cx, |context, cx| {
  82        context
  83            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  84            .unwrap()
  85    });
  86    assert_eq!(
  87        messages(&context, cx),
  88        vec![
  89            (message_1.id, Role::User, 0..2),
  90            (message_2.id, Role::Assistant, 2..4),
  91            (message_3.id, Role::User, 4..4)
  92        ]
  93    );
  94
  95    let message_4 = context.update(cx, |context, cx| {
  96        context
  97            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  98            .unwrap()
  99    });
 100    assert_eq!(
 101        messages(&context, cx),
 102        vec![
 103            (message_1.id, Role::User, 0..2),
 104            (message_2.id, Role::Assistant, 2..4),
 105            (message_4.id, Role::User, 4..5),
 106            (message_3.id, Role::User, 5..5),
 107        ]
 108    );
 109
 110    buffer.update(cx, |buffer, cx| {
 111        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 112    });
 113    assert_eq!(
 114        messages(&context, cx),
 115        vec![
 116            (message_1.id, Role::User, 0..2),
 117            (message_2.id, Role::Assistant, 2..4),
 118            (message_4.id, Role::User, 4..6),
 119            (message_3.id, Role::User, 6..7),
 120        ]
 121    );
 122
 123    // Deleting across message boundaries merges the messages.
 124    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 125    assert_eq!(
 126        messages(&context, cx),
 127        vec![
 128            (message_1.id, Role::User, 0..3),
 129            (message_3.id, Role::User, 3..4),
 130        ]
 131    );
 132
 133    // Undoing the deletion should also undo the merge.
 134    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 135    assert_eq!(
 136        messages(&context, cx),
 137        vec![
 138            (message_1.id, Role::User, 0..2),
 139            (message_2.id, Role::Assistant, 2..4),
 140            (message_4.id, Role::User, 4..6),
 141            (message_3.id, Role::User, 6..7),
 142        ]
 143    );
 144
 145    // Redoing the deletion should also redo the merge.
 146    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 147    assert_eq!(
 148        messages(&context, cx),
 149        vec![
 150            (message_1.id, Role::User, 0..3),
 151            (message_3.id, Role::User, 3..4),
 152        ]
 153    );
 154
 155    // Ensure we can still insert after a merged message.
 156    let message_5 = context.update(cx, |context, cx| {
 157        context
 158            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 159            .unwrap()
 160    });
 161    assert_eq!(
 162        messages(&context, cx),
 163        vec![
 164            (message_1.id, Role::User, 0..3),
 165            (message_5.id, Role::System, 3..4),
 166            (message_3.id, Role::User, 4..5)
 167        ]
 168    );
 169}
 170
 171#[gpui::test]
 172fn test_message_splitting(cx: &mut AppContext) {
 173    let settings_store = SettingsStore::test(cx);
 174    cx.set_global(settings_store);
 175    LanguageModelRegistry::test(cx);
 176    assistant_panel::init(cx);
 177    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 178
 179    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 180    let context =
 181        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 182    let buffer = context.read(cx).buffer.clone();
 183
 184    let message_1 = context.read(cx).message_anchors[0].clone();
 185    assert_eq!(
 186        messages(&context, cx),
 187        vec![(message_1.id, Role::User, 0..0)]
 188    );
 189
 190    buffer.update(cx, |buffer, cx| {
 191        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 192    });
 193
 194    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 195    let message_2 = message_2.unwrap();
 196
 197    // We recycle newlines in the middle of a split message
 198    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 199    assert_eq!(
 200        messages(&context, cx),
 201        vec![
 202            (message_1.id, Role::User, 0..4),
 203            (message_2.id, Role::User, 4..16),
 204        ]
 205    );
 206
 207    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 208    let message_3 = message_3.unwrap();
 209
 210    // We don't recycle newlines at the end of a split message
 211    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 212    assert_eq!(
 213        messages(&context, cx),
 214        vec![
 215            (message_1.id, Role::User, 0..4),
 216            (message_3.id, Role::User, 4..5),
 217            (message_2.id, Role::User, 5..17),
 218        ]
 219    );
 220
 221    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 222    let message_4 = message_4.unwrap();
 223    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 224    assert_eq!(
 225        messages(&context, cx),
 226        vec![
 227            (message_1.id, Role::User, 0..4),
 228            (message_3.id, Role::User, 4..5),
 229            (message_2.id, Role::User, 5..9),
 230            (message_4.id, Role::User, 9..17),
 231        ]
 232    );
 233
 234    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 235    let message_5 = message_5.unwrap();
 236    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 237    assert_eq!(
 238        messages(&context, cx),
 239        vec![
 240            (message_1.id, Role::User, 0..4),
 241            (message_3.id, Role::User, 4..5),
 242            (message_2.id, Role::User, 5..9),
 243            (message_4.id, Role::User, 9..10),
 244            (message_5.id, Role::User, 10..18),
 245        ]
 246    );
 247
 248    let (message_6, message_7) =
 249        context.update(cx, |context, cx| context.split_message(14..16, cx));
 250    let message_6 = message_6.unwrap();
 251    let message_7 = message_7.unwrap();
 252    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 253    assert_eq!(
 254        messages(&context, cx),
 255        vec![
 256            (message_1.id, Role::User, 0..4),
 257            (message_3.id, Role::User, 4..5),
 258            (message_2.id, Role::User, 5..9),
 259            (message_4.id, Role::User, 9..10),
 260            (message_5.id, Role::User, 10..14),
 261            (message_6.id, Role::User, 14..17),
 262            (message_7.id, Role::User, 17..19),
 263        ]
 264    );
 265}
 266
 267#[gpui::test]
 268fn test_messages_for_offsets(cx: &mut AppContext) {
 269    let settings_store = SettingsStore::test(cx);
 270    LanguageModelRegistry::test(cx);
 271    cx.set_global(settings_store);
 272    assistant_panel::init(cx);
 273    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 274    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 275    let context =
 276        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 277    let buffer = context.read(cx).buffer.clone();
 278
 279    let message_1 = context.read(cx).message_anchors[0].clone();
 280    assert_eq!(
 281        messages(&context, cx),
 282        vec![(message_1.id, Role::User, 0..0)]
 283    );
 284
 285    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 286    let message_2 = context
 287        .update(cx, |context, cx| {
 288            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 289        })
 290        .unwrap();
 291    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 292
 293    let message_3 = context
 294        .update(cx, |context, cx| {
 295            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 296        })
 297        .unwrap();
 298    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 299
 300    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 301    assert_eq!(
 302        messages(&context, cx),
 303        vec![
 304            (message_1.id, Role::User, 0..4),
 305            (message_2.id, Role::User, 4..8),
 306            (message_3.id, Role::User, 8..11)
 307        ]
 308    );
 309
 310    assert_eq!(
 311        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 312        [message_1.id, message_2.id, message_3.id]
 313    );
 314    assert_eq!(
 315        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 316        [message_1.id, message_3.id]
 317    );
 318
 319    let message_4 = context
 320        .update(cx, |context, cx| {
 321            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 322        })
 323        .unwrap();
 324    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 325    assert_eq!(
 326        messages(&context, cx),
 327        vec![
 328            (message_1.id, Role::User, 0..4),
 329            (message_2.id, Role::User, 4..8),
 330            (message_3.id, Role::User, 8..12),
 331            (message_4.id, Role::User, 12..12)
 332        ]
 333    );
 334    assert_eq!(
 335        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 336        [message_1.id, message_2.id, message_3.id, message_4.id]
 337    );
 338
 339    fn message_ids_for_offsets(
 340        context: &Model<Context>,
 341        offsets: &[usize],
 342        cx: &AppContext,
 343    ) -> Vec<MessageId> {
 344        context
 345            .read(cx)
 346            .messages_for_offsets(offsets.iter().copied(), cx)
 347            .into_iter()
 348            .map(|message| message.id)
 349            .collect()
 350    }
 351}
 352
 353#[gpui::test]
 354async fn test_slash_commands(cx: &mut TestAppContext) {
 355    let settings_store = cx.update(SettingsStore::test);
 356    cx.set_global(settings_store);
 357    cx.update(LanguageModelRegistry::test);
 358    cx.update(Project::init_settings);
 359    cx.update(assistant_panel::init);
 360    let fs = FakeFs::new(cx.background_executor.clone());
 361
 362    fs.insert_tree(
 363        "/test",
 364        json!({
 365            "src": {
 366                "lib.rs": "fn one() -> usize { 1 }",
 367                "main.rs": "
 368                    use crate::one;
 369                    fn main() { one(); }
 370                ".unindent(),
 371            }
 372        }),
 373    )
 374    .await;
 375
 376    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 377    slash_command_registry.register_command(file_command::FileSlashCommand, false);
 378
 379    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 380    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 381    let context =
 382        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 383
 384    let output_ranges = Rc::new(RefCell::new(HashSet::default()));
 385    context.update(cx, |_, cx| {
 386        cx.subscribe(&context, {
 387            let ranges = output_ranges.clone();
 388            move |_, _, event, _| match event {
 389                ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
 390                    for range in removed {
 391                        ranges.borrow_mut().remove(range);
 392                    }
 393                    for command in updated {
 394                        ranges.borrow_mut().insert(command.source_range.clone());
 395                    }
 396                }
 397                _ => {}
 398            }
 399        })
 400        .detach();
 401    });
 402
 403    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 404
 405    // Insert a slash command
 406    buffer.update(cx, |buffer, cx| {
 407        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 408    });
 409    assert_text_and_output_ranges(
 410        &buffer,
 411        &output_ranges.borrow(),
 412        "
 413        «/file src/lib.rs»
 414        "
 415        .unindent()
 416        .trim_end(),
 417        cx,
 418    );
 419
 420    // Edit the argument of the slash command.
 421    buffer.update(cx, |buffer, cx| {
 422        let edit_offset = buffer.text().find("lib.rs").unwrap();
 423        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 424    });
 425    assert_text_and_output_ranges(
 426        &buffer,
 427        &output_ranges.borrow(),
 428        "
 429        «/file src/main.rs»
 430        "
 431        .unindent()
 432        .trim_end(),
 433        cx,
 434    );
 435
 436    // Edit the name of the slash command, using one that doesn't exist.
 437    buffer.update(cx, |buffer, cx| {
 438        let edit_offset = buffer.text().find("/file").unwrap();
 439        buffer.edit(
 440            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 441            None,
 442            cx,
 443        );
 444    });
 445    assert_text_and_output_ranges(
 446        &buffer,
 447        &output_ranges.borrow(),
 448        "
 449        /unknown src/main.rs
 450        "
 451        .unindent()
 452        .trim_end(),
 453        cx,
 454    );
 455
 456    #[track_caller]
 457    fn assert_text_and_output_ranges(
 458        buffer: &Model<Buffer>,
 459        ranges: &HashSet<Range<language::Anchor>>,
 460        expected_marked_text: &str,
 461        cx: &mut TestAppContext,
 462    ) {
 463        let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
 464        let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
 465            let mut ranges = ranges
 466                .iter()
 467                .map(|range| range.to_offset(buffer))
 468                .collect::<Vec<_>>();
 469            ranges.sort_by_key(|a| a.start);
 470            (buffer.text(), ranges)
 471        });
 472
 473        assert_eq!(actual_text, expected_text);
 474        assert_eq!(actual_ranges, expected_ranges);
 475    }
 476}
 477
 478#[gpui::test]
 479async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 480    cx.update(prompt_library::init);
 481    let settings_store = cx.update(SettingsStore::test);
 482    cx.set_global(settings_store);
 483    cx.update(language::init);
 484    cx.update(Project::init_settings);
 485    let fs = FakeFs::new(cx.executor());
 486    let project = Project::test(fs, [Path::new("/root")], cx).await;
 487    cx.update(LanguageModelRegistry::test);
 488
 489    cx.update(assistant_panel::init);
 490    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 491
 492    // Create a new context
 493    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 494    let context = cx.new_model(|cx| {
 495        Context::local(
 496            registry.clone(),
 497            Some(project),
 498            None,
 499            prompt_builder.clone(),
 500            cx,
 501        )
 502    });
 503
 504    // Insert an assistant message to simulate a response.
 505    let assistant_message_id = context.update(cx, |context, cx| {
 506        let user_message_id = context.messages(cx).next().unwrap().id;
 507        context
 508            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 509            .unwrap()
 510            .id
 511    });
 512
 513    // No edit tags
 514    edit(
 515        &context,
 516        "
 517
 518        «one
 519        two
 520        »",
 521        cx,
 522    );
 523    expect_steps(
 524        &context,
 525        "
 526
 527        one
 528        two
 529        ",
 530        &[],
 531        cx,
 532    );
 533
 534    // Partial edit step tag is added
 535    edit(
 536        &context,
 537        "
 538
 539        one
 540        two
 541        «
 542        <step»",
 543        cx,
 544    );
 545    expect_steps(
 546        &context,
 547        "
 548
 549        one
 550        two
 551
 552        <step",
 553        &[],
 554        cx,
 555    );
 556
 557    // The rest of the step tag is added. The unclosed
 558    // step is treated as incomplete.
 559    edit(
 560        &context,
 561        "
 562
 563        one
 564        two
 565
 566        <step«>
 567        Add a second function
 568
 569        ```rust
 570        fn two() {}
 571        ```
 572
 573        <edit>»",
 574        cx,
 575    );
 576    expect_steps(
 577        &context,
 578        "
 579
 580        one
 581        two
 582
 583        «<step>
 584        Add a second function
 585
 586        ```rust
 587        fn two() {}
 588        ```
 589
 590        <edit>»",
 591        &[&[]],
 592        cx,
 593    );
 594
 595    // The full suggestion is added
 596    edit(
 597        &context,
 598        "
 599
 600        one
 601        two
 602
 603        <step>
 604        Add a second function
 605
 606        ```rust
 607        fn two() {}
 608        ```
 609
 610        <edit>«
 611        <path>src/lib.rs</path>
 612        <operation>insert_after</operation>
 613        <search>fn one</search>
 614        <description>add a `two` function</description>
 615        </edit>
 616        </step>
 617
 618        also,»",
 619        cx,
 620    );
 621    expect_steps(
 622        &context,
 623        "
 624
 625        one
 626        two
 627
 628        «<step>
 629        Add a second function
 630
 631        ```rust
 632        fn two() {}
 633        ```
 634
 635        <edit>
 636        <path>src/lib.rs</path>
 637        <operation>insert_after</operation>
 638        <search>fn one</search>
 639        <description>add a `two` function</description>
 640        </edit>
 641        </step>»
 642
 643        also,",
 644        &[&[WorkflowStepEdit {
 645            path: "src/lib.rs".into(),
 646            kind: WorkflowStepEditKind::InsertAfter {
 647                search: "fn one".into(),
 648                description: "add a `two` function".into(),
 649            },
 650        }]],
 651        cx,
 652    );
 653
 654    // The step is manually edited.
 655    edit(
 656        &context,
 657        "
 658
 659        one
 660        two
 661
 662        <step>
 663        Add a second function
 664
 665        ```rust
 666        fn two() {}
 667        ```
 668
 669        <edit>
 670        <path>src/lib.rs</path>
 671        <operation>insert_after</operation>
 672        <search>«fn zero»</search>
 673        <description>add a `two` function</description>
 674        </edit>
 675        </step>
 676
 677        also,",
 678        cx,
 679    );
 680    expect_steps(
 681        &context,
 682        "
 683
 684        one
 685        two
 686
 687        «<step>
 688        Add a second function
 689
 690        ```rust
 691        fn two() {}
 692        ```
 693
 694        <edit>
 695        <path>src/lib.rs</path>
 696        <operation>insert_after</operation>
 697        <search>fn zero</search>
 698        <description>add a `two` function</description>
 699        </edit>
 700        </step>»
 701
 702        also,",
 703        &[&[WorkflowStepEdit {
 704            path: "src/lib.rs".into(),
 705            kind: WorkflowStepEditKind::InsertAfter {
 706                search: "fn zero".into(),
 707                description: "add a `two` function".into(),
 708            },
 709        }]],
 710        cx,
 711    );
 712
 713    // When setting the message role to User, the steps are cleared.
 714    context.update(cx, |context, cx| {
 715        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 716        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 717    });
 718    expect_steps(
 719        &context,
 720        "
 721
 722        one
 723        two
 724
 725        <step>
 726        Add a second function
 727
 728        ```rust
 729        fn two() {}
 730        ```
 731
 732        <edit>
 733        <path>src/lib.rs</path>
 734        <operation>insert_after</operation>
 735        <search>fn zero</search>
 736        <description>add a `two` function</description>
 737        </edit>
 738        </step>
 739
 740        also,",
 741        &[],
 742        cx,
 743    );
 744
 745    // When setting the message role back to Assistant, the steps are reparsed.
 746    context.update(cx, |context, cx| {
 747        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 748    });
 749    expect_steps(
 750        &context,
 751        "
 752
 753        one
 754        two
 755
 756        «<step>
 757        Add a second function
 758
 759        ```rust
 760        fn two() {}
 761        ```
 762
 763        <edit>
 764        <path>src/lib.rs</path>
 765        <operation>insert_after</operation>
 766        <search>fn zero</search>
 767        <description>add a `two` function</description>
 768        </edit>
 769        </step>»
 770
 771        also,",
 772        &[&[WorkflowStepEdit {
 773            path: "src/lib.rs".into(),
 774            kind: WorkflowStepEditKind::InsertAfter {
 775                search: "fn zero".into(),
 776                description: "add a `two` function".into(),
 777            },
 778        }]],
 779        cx,
 780    );
 781
 782    // Ensure steps are re-parsed when deserializing.
 783    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 784    let deserialized_context = cx.new_model(|cx| {
 785        Context::deserialize(
 786            serialized_context,
 787            Default::default(),
 788            registry.clone(),
 789            prompt_builder.clone(),
 790            None,
 791            None,
 792            cx,
 793        )
 794    });
 795    expect_steps(
 796        &deserialized_context,
 797        "
 798
 799        one
 800        two
 801
 802        «<step>
 803        Add a second function
 804
 805        ```rust
 806        fn two() {}
 807        ```
 808
 809        <edit>
 810        <path>src/lib.rs</path>
 811        <operation>insert_after</operation>
 812        <search>fn zero</search>
 813        <description>add a `two` function</description>
 814        </edit>
 815        </step>»
 816
 817        also,",
 818        &[&[WorkflowStepEdit {
 819            path: "src/lib.rs".into(),
 820            kind: WorkflowStepEditKind::InsertAfter {
 821                search: "fn zero".into(),
 822                description: "add a `two` function".into(),
 823            },
 824        }]],
 825        cx,
 826    );
 827
 828    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
 829        context.update(cx, |context, cx| {
 830            context.buffer.update(cx, |buffer, cx| {
 831                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
 832            });
 833        });
 834        cx.executor().run_until_parked();
 835    }
 836
 837    fn expect_steps(
 838        context: &Model<Context>,
 839        expected_marked_text: &str,
 840        expected_suggestions: &[&[WorkflowStepEdit]],
 841        cx: &mut TestAppContext,
 842    ) {
 843        context.update(cx, |context, cx| {
 844            let expected_marked_text = expected_marked_text.unindent();
 845            let (expected_text, expected_ranges) = marked_text_ranges(&expected_marked_text, false);
 846            context.buffer.read_with(cx, |buffer, _| {
 847                assert_eq!(buffer.text(), expected_text);
 848                let ranges = context
 849                    .workflow_steps
 850                    .iter()
 851                    .map(|entry| entry.range.to_offset(buffer))
 852                    .collect::<Vec<_>>();
 853                let marked = generate_marked_text(&expected_text, &ranges, false);
 854                assert_eq!(
 855                    marked,
 856                    expected_marked_text,
 857                    "unexpected suggestion ranges. actual: {ranges:?}, expected: {expected_ranges:?}"
 858                );
 859                let suggestions = context
 860                    .workflow_steps
 861                    .iter()
 862                    .map(|step| {
 863                        step.edits
 864                            .iter()
 865                            .map(|edit| {
 866                                let edit = edit.as_ref().unwrap();
 867                                WorkflowStepEdit {
 868                                    path: edit.path.clone(),
 869                                    kind: edit.kind.clone(),
 870                                }
 871                            })
 872                            .collect::<Vec<_>>()
 873                    })
 874                    .collect::<Vec<_>>();
 875
 876                assert_eq!(suggestions, expected_suggestions);
 877            });
 878        });
 879    }
 880}
 881
 882#[gpui::test]
 883async fn test_serialization(cx: &mut TestAppContext) {
 884    let settings_store = cx.update(SettingsStore::test);
 885    cx.set_global(settings_store);
 886    cx.update(LanguageModelRegistry::test);
 887    cx.update(assistant_panel::init);
 888    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 889    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 890    let context =
 891        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 892    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 893    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
 894    let message_1 = context.update(cx, |context, cx| {
 895        context
 896            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 897            .unwrap()
 898    });
 899    let message_2 = context.update(cx, |context, cx| {
 900        context
 901            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 902            .unwrap()
 903    });
 904    buffer.update(cx, |buffer, cx| {
 905        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 906        buffer.finalize_last_transaction();
 907    });
 908    let _message_3 = context.update(cx, |context, cx| {
 909        context
 910            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 911            .unwrap()
 912    });
 913    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 914    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 915    assert_eq!(
 916        cx.read(|cx| messages(&context, cx)),
 917        [
 918            (message_0, Role::User, 0..2),
 919            (message_1.id, Role::Assistant, 2..6),
 920            (message_2.id, Role::System, 6..6),
 921        ]
 922    );
 923
 924    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 925    let deserialized_context = cx.new_model(|cx| {
 926        Context::deserialize(
 927            serialized_context,
 928            Default::default(),
 929            registry.clone(),
 930            prompt_builder.clone(),
 931            None,
 932            None,
 933            cx,
 934        )
 935    });
 936    let deserialized_buffer =
 937        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
 938    assert_eq!(
 939        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 940        "a\nb\nc\n"
 941    );
 942    assert_eq!(
 943        cx.read(|cx| messages(&deserialized_context, cx)),
 944        [
 945            (message_0, Role::User, 0..2),
 946            (message_1.id, Role::Assistant, 2..6),
 947            (message_2.id, Role::System, 6..6),
 948        ]
 949    );
 950}
 951
// Randomized convergence test for collaborative `Context` editing.
//
// Creates between `MIN_PEERS` and `MAX_PEERS` replicas (defaults 2 and 5) of a
// single context, connected through a simulated `Network`, and performs
// `OPERATIONS` (default 50) random mutations interleaved with random network
// activity. Once the mutation budget is spent, the loop keeps running network
// steps until the network is idle and no peer is disconnected. Finally every
// replica must agree with replica 0 on buffer text, message anchors, message
// metadata and slash-command output sections, with no pending operations.
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    // Tunables, overridable via environment variables. A value that is set but
    // fails to parse panics instead of falling back to the default.
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    // Register a few fake commands so the random slash-command insertions
    // below have names to choose from.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    // All replicas share one context id; replica `i` gets replica id `i`.
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });

        // Forward every operation this replica emits to the simulated
        // network as a broadcast from replica `i`.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    // Number of random mutations still to perform.
    let mut mutation_count = operations;

    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // Rolls 0..=84 mutate the chosen replica while the budget lasts. The
        // `_` arm (network activity) handles every other roll, including the
        // mutation rolls once `mutation_count` has reached zero.
        match rng.gen_range(0..100) {
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            60..=74 if mutation_count > 0 => {
                // Insert `/<random command>` on its own line at a random
                // offset, then attach a ready-made random output to it.
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Returns the byte range of the command text itself,
                    // skipping the leading newline that was inserted.
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    let output_len = rng.gen_range(1..=10);
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(output_len)
                        .collect::<String>();

                    // NOTE(review): section bounds come from `0..=output_len`,
                    // which counts chars. If `RandomCharIter` can yield
                    // multi-byte chars and section ranges are byte offsets,
                    // these bounds may not fall on char boundaries — confirm.
                    let num_sections = rng.gen_range(0..=3);
                    let mut sections = Vec::with_capacity(num_sections);
                    for _ in 0..num_sections {
                        let section_start = rng.gen_range(0..output_len);
                        let section_end = rng.gen_range(section_start..=output_len);
                        sections.push(SlashCommandOutputSection {
                            range: section_start..section_end,
                            icon: ui::IconName::Ai,
                            label: "section".into(),
                        });
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?}",
                        context_index,
                        command_range,
                        sections
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        Task::ready(Ok(SlashCommandOutput {
                            text: output_text,
                            sections,
                            run_commands_in_text: false,
                        })),
                        true,
                        false,
                        cx,
                    );
                });
                // Presumably lets the work kicked off by `insert_command_output`
                // settle before the next step.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    // Reconnect this replica and resynchronize it with the
                    // host (replica 0).
                    // NOTE(review): the second argument presumably identifies
                    // the replica being resynced against — confirm against
                    // `Network::reconnect_peer`.
                    network.lock().reconnect_peer(replica_id, 0);

                    // Assuming `serialize_ops(version)` yields the ops not
                    // covered by `version`: the guest's ops unknown to the host
                    // are broadcast, and the host's ops unknown to the guest
                    // are applied directly to the guest.
                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context
                        .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
                        .unwrap();
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Occasionally drop a non-host replica off the network.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    // Deliver the operations queued for this replica.
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context
                        .update(cx, |context, cx| context.apply_ops(ops, cx))
                        .unwrap();
                }
            }
        }
    }

    // Every replica must have converged to replica 0's state.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty());
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1222
1223#[gpui::test]
1224fn test_mark_cache_anchors(cx: &mut AppContext) {
1225    let settings_store = SettingsStore::test(cx);
1226    LanguageModelRegistry::test(cx);
1227    cx.set_global(settings_store);
1228    assistant_panel::init(cx);
1229    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1230    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1231    let context =
1232        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1233    let buffer = context.read(cx).buffer.clone();
1234
1235    // Create a test cache configuration
1236    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1237        max_cache_anchors: 3,
1238        should_speculate: true,
1239        min_total_token: 10,
1240    });
1241
1242    let message_1 = context.read(cx).message_anchors[0].clone();
1243
1244    context.update(cx, |context, cx| {
1245        context.mark_cache_anchors(cache_configuration, false, cx)
1246    });
1247
1248    assert_eq!(
1249        messages_cache(&context, cx)
1250            .iter()
1251            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1252            .count(),
1253        0,
1254        "Empty messages should not have any cache anchors."
1255    );
1256
1257    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1258    let message_2 = context
1259        .update(cx, |context, cx| {
1260            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1261        })
1262        .unwrap();
1263
1264    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1265    let message_3 = context
1266        .update(cx, |context, cx| {
1267            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1268        })
1269        .unwrap();
1270    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1271
1272    context.update(cx, |context, cx| {
1273        context.mark_cache_anchors(cache_configuration, false, cx)
1274    });
1275    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1276    assert_eq!(
1277        messages_cache(&context, cx)
1278            .iter()
1279            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1280            .count(),
1281        0,
1282        "Messages should not be marked for cache before going over the token minimum."
1283    );
1284    context.update(cx, |context, _| {
1285        context.token_count = Some(20);
1286    });
1287
1288    context.update(cx, |context, cx| {
1289        context.mark_cache_anchors(cache_configuration, true, cx)
1290    });
1291    assert_eq!(
1292        messages_cache(&context, cx)
1293            .iter()
1294            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1295            .collect::<Vec<bool>>(),
1296        vec![true, true, false],
1297        "Last message should not be an anchor on speculative request."
1298    );
1299
1300    context
1301        .update(cx, |context, cx| {
1302            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1303        })
1304        .unwrap();
1305
1306    context.update(cx, |context, cx| {
1307        context.mark_cache_anchors(cache_configuration, false, cx)
1308    });
1309    assert_eq!(
1310        messages_cache(&context, cx)
1311            .iter()
1312            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1313            .collect::<Vec<bool>>(),
1314        vec![false, true, true, false],
1315        "Most recent message should also be cached if not a speculative request."
1316    );
1317    context.update(cx, |context, cx| {
1318        context.update_cache_status_for_completion(cx)
1319    });
1320    assert_eq!(
1321        messages_cache(&context, cx)
1322            .iter()
1323            .map(|(_, cache)| cache
1324                .as_ref()
1325                .map_or(None, |cache| Some(cache.status.clone())))
1326            .collect::<Vec<Option<CacheStatus>>>(),
1327        vec![
1328            Some(CacheStatus::Cached),
1329            Some(CacheStatus::Cached),
1330            Some(CacheStatus::Cached),
1331            None
1332        ],
1333        "All user messages prior to anchor should be marked as cached."
1334    );
1335
1336    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1337    context.update(cx, |context, cx| {
1338        context.mark_cache_anchors(cache_configuration, false, cx)
1339    });
1340    assert_eq!(
1341        messages_cache(&context, cx)
1342            .iter()
1343            .map(|(_, cache)| cache
1344                .as_ref()
1345                .map_or(None, |cache| Some(cache.status.clone())))
1346            .collect::<Vec<Option<CacheStatus>>>(),
1347        vec![
1348            Some(CacheStatus::Cached),
1349            Some(CacheStatus::Cached),
1350            Some(CacheStatus::Pending),
1351            None
1352        ],
1353        "Modifying a message should invalidate it's cache but leave previous messages."
1354    );
1355    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1356    context.update(cx, |context, cx| {
1357        context.mark_cache_anchors(cache_configuration, false, cx)
1358    });
1359    assert_eq!(
1360        messages_cache(&context, cx)
1361            .iter()
1362            .map(|(_, cache)| cache
1363                .as_ref()
1364                .map_or(None, |cache| Some(cache.status.clone())))
1365            .collect::<Vec<Option<CacheStatus>>>(),
1366        vec![
1367            Some(CacheStatus::Pending),
1368            Some(CacheStatus::Pending),
1369            Some(CacheStatus::Pending),
1370            None
1371        ],
1372        "Modifying a message should invalidate all future messages."
1373    );
1374}
1375
1376fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1377    context
1378        .read(cx)
1379        .messages(cx)
1380        .map(|message| (message.id, message.role, message.offset_range))
1381        .collect()
1382}
1383
1384fn messages_cache(
1385    context: &Model<Context>,
1386    cx: &AppContext,
1387) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1388    context
1389        .read(cx)
1390        .messages(cx)
1391        .map(|message| (message.id, message.cache.clone()))
1392        .collect()
1393}
1394
1395#[derive(Clone)]
1396struct FakeSlashCommand(String);
1397
1398impl SlashCommand for FakeSlashCommand {
1399    fn name(&self) -> String {
1400        self.0.clone()
1401    }
1402
1403    fn description(&self) -> String {
1404        format!("Fake slash command: {}", self.0)
1405    }
1406
1407    fn menu_text(&self) -> String {
1408        format!("Run fake command: {}", self.0)
1409    }
1410
1411    fn complete_argument(
1412        self: Arc<Self>,
1413        _arguments: &[String],
1414        _cancel: Arc<AtomicBool>,
1415        _workspace: Option<WeakView<Workspace>>,
1416        _cx: &mut WindowContext,
1417    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1418        Task::ready(Ok(vec![]))
1419    }
1420
1421    fn requires_argument(&self) -> bool {
1422        false
1423    }
1424
1425    fn run(
1426        self: Arc<Self>,
1427        _arguments: &[String],
1428        _workspace: WeakView<Workspace>,
1429        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1430        _cx: &mut WindowContext,
1431    ) -> Task<Result<SlashCommandOutput>> {
1432        Task::ready(Ok(SlashCommandOutput {
1433            text: format!("Executed fake command: {}", self.0),
1434            sections: vec![],
1435            run_commands_in_text: false,
1436        }))
1437    }
1438}