context_tests.rs

   1use super::{AssistantEdit, MessageCacheMetadata};
   2use crate::{
   3    assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
   4    Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
   5};
   6use anyhow::Result;
   7use assistant_slash_command::{
   8    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
   9    SlashCommandRegistry,
  10};
  11use collections::HashSet;
  12use fs::FakeFs;
  13use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  14use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  15use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  16use parking_lot::Mutex;
  17use pretty_assertions::assert_eq;
  18use project::Project;
  19use rand::prelude::*;
  20use serde_json::json;
  21use settings::SettingsStore;
  22use std::{
  23    cell::RefCell,
  24    env,
  25    ops::Range,
  26    path::Path,
  27    rc::Rc,
  28    sync::{atomic::AtomicBool, Arc},
  29};
  30use text::{network::Network, OffsetRangeExt as _, ReplicaId};
  31use ui::{Context as _, WindowContext};
  32use unindent::Unindent;
  33use util::{
  34    test::{generate_marked_text, marked_text_ranges},
  35    RandomCharIter,
  36};
  37use workspace::Workspace;
  38
  39#[gpui::test]
  40fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  41    let settings_store = SettingsStore::test(cx);
  42    LanguageModelRegistry::test(cx);
  43    cx.set_global(settings_store);
  44    assistant_panel::init(cx);
  45    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  46    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  47    let context =
  48        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
  49    let buffer = context.read(cx).buffer.clone();
  50
  51    let message_1 = context.read(cx).message_anchors[0].clone();
  52    assert_eq!(
  53        messages(&context, cx),
  54        vec![(message_1.id, Role::User, 0..0)]
  55    );
  56
  57    let message_2 = context.update(cx, |context, cx| {
  58        context
  59            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  60            .unwrap()
  61    });
  62    assert_eq!(
  63        messages(&context, cx),
  64        vec![
  65            (message_1.id, Role::User, 0..1),
  66            (message_2.id, Role::Assistant, 1..1)
  67        ]
  68    );
  69
  70    buffer.update(cx, |buffer, cx| {
  71        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  72    });
  73    assert_eq!(
  74        messages(&context, cx),
  75        vec![
  76            (message_1.id, Role::User, 0..2),
  77            (message_2.id, Role::Assistant, 2..3)
  78        ]
  79    );
  80
  81    let message_3 = context.update(cx, |context, cx| {
  82        context
  83            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  84            .unwrap()
  85    });
  86    assert_eq!(
  87        messages(&context, cx),
  88        vec![
  89            (message_1.id, Role::User, 0..2),
  90            (message_2.id, Role::Assistant, 2..4),
  91            (message_3.id, Role::User, 4..4)
  92        ]
  93    );
  94
  95    let message_4 = context.update(cx, |context, cx| {
  96        context
  97            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  98            .unwrap()
  99    });
 100    assert_eq!(
 101        messages(&context, cx),
 102        vec![
 103            (message_1.id, Role::User, 0..2),
 104            (message_2.id, Role::Assistant, 2..4),
 105            (message_4.id, Role::User, 4..5),
 106            (message_3.id, Role::User, 5..5),
 107        ]
 108    );
 109
 110    buffer.update(cx, |buffer, cx| {
 111        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 112    });
 113    assert_eq!(
 114        messages(&context, cx),
 115        vec![
 116            (message_1.id, Role::User, 0..2),
 117            (message_2.id, Role::Assistant, 2..4),
 118            (message_4.id, Role::User, 4..6),
 119            (message_3.id, Role::User, 6..7),
 120        ]
 121    );
 122
 123    // Deleting across message boundaries merges the messages.
 124    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 125    assert_eq!(
 126        messages(&context, cx),
 127        vec![
 128            (message_1.id, Role::User, 0..3),
 129            (message_3.id, Role::User, 3..4),
 130        ]
 131    );
 132
 133    // Undoing the deletion should also undo the merge.
 134    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 135    assert_eq!(
 136        messages(&context, cx),
 137        vec![
 138            (message_1.id, Role::User, 0..2),
 139            (message_2.id, Role::Assistant, 2..4),
 140            (message_4.id, Role::User, 4..6),
 141            (message_3.id, Role::User, 6..7),
 142        ]
 143    );
 144
 145    // Redoing the deletion should also redo the merge.
 146    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 147    assert_eq!(
 148        messages(&context, cx),
 149        vec![
 150            (message_1.id, Role::User, 0..3),
 151            (message_3.id, Role::User, 3..4),
 152        ]
 153    );
 154
 155    // Ensure we can still insert after a merged message.
 156    let message_5 = context.update(cx, |context, cx| {
 157        context
 158            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 159            .unwrap()
 160    });
 161    assert_eq!(
 162        messages(&context, cx),
 163        vec![
 164            (message_1.id, Role::User, 0..3),
 165            (message_5.id, Role::System, 3..4),
 166            (message_3.id, Role::User, 4..5)
 167        ]
 168    );
 169}
 170
 171#[gpui::test]
 172fn test_message_splitting(cx: &mut AppContext) {
 173    let settings_store = SettingsStore::test(cx);
 174    cx.set_global(settings_store);
 175    LanguageModelRegistry::test(cx);
 176    assistant_panel::init(cx);
 177    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 178
 179    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 180    let context =
 181        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 182    let buffer = context.read(cx).buffer.clone();
 183
 184    let message_1 = context.read(cx).message_anchors[0].clone();
 185    assert_eq!(
 186        messages(&context, cx),
 187        vec![(message_1.id, Role::User, 0..0)]
 188    );
 189
 190    buffer.update(cx, |buffer, cx| {
 191        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 192    });
 193
 194    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 195    let message_2 = message_2.unwrap();
 196
 197    // We recycle newlines in the middle of a split message
 198    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 199    assert_eq!(
 200        messages(&context, cx),
 201        vec![
 202            (message_1.id, Role::User, 0..4),
 203            (message_2.id, Role::User, 4..16),
 204        ]
 205    );
 206
 207    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 208    let message_3 = message_3.unwrap();
 209
 210    // We don't recycle newlines at the end of a split message
 211    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 212    assert_eq!(
 213        messages(&context, cx),
 214        vec![
 215            (message_1.id, Role::User, 0..4),
 216            (message_3.id, Role::User, 4..5),
 217            (message_2.id, Role::User, 5..17),
 218        ]
 219    );
 220
 221    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 222    let message_4 = message_4.unwrap();
 223    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 224    assert_eq!(
 225        messages(&context, cx),
 226        vec![
 227            (message_1.id, Role::User, 0..4),
 228            (message_3.id, Role::User, 4..5),
 229            (message_2.id, Role::User, 5..9),
 230            (message_4.id, Role::User, 9..17),
 231        ]
 232    );
 233
 234    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 235    let message_5 = message_5.unwrap();
 236    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 237    assert_eq!(
 238        messages(&context, cx),
 239        vec![
 240            (message_1.id, Role::User, 0..4),
 241            (message_3.id, Role::User, 4..5),
 242            (message_2.id, Role::User, 5..9),
 243            (message_4.id, Role::User, 9..10),
 244            (message_5.id, Role::User, 10..18),
 245        ]
 246    );
 247
 248    let (message_6, message_7) =
 249        context.update(cx, |context, cx| context.split_message(14..16, cx));
 250    let message_6 = message_6.unwrap();
 251    let message_7 = message_7.unwrap();
 252    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 253    assert_eq!(
 254        messages(&context, cx),
 255        vec![
 256            (message_1.id, Role::User, 0..4),
 257            (message_3.id, Role::User, 4..5),
 258            (message_2.id, Role::User, 5..9),
 259            (message_4.id, Role::User, 9..10),
 260            (message_5.id, Role::User, 10..14),
 261            (message_6.id, Role::User, 14..17),
 262            (message_7.id, Role::User, 17..19),
 263        ]
 264    );
 265}
 266
 267#[gpui::test]
 268fn test_messages_for_offsets(cx: &mut AppContext) {
 269    let settings_store = SettingsStore::test(cx);
 270    LanguageModelRegistry::test(cx);
 271    cx.set_global(settings_store);
 272    assistant_panel::init(cx);
 273    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 274    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 275    let context =
 276        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 277    let buffer = context.read(cx).buffer.clone();
 278
 279    let message_1 = context.read(cx).message_anchors[0].clone();
 280    assert_eq!(
 281        messages(&context, cx),
 282        vec![(message_1.id, Role::User, 0..0)]
 283    );
 284
 285    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 286    let message_2 = context
 287        .update(cx, |context, cx| {
 288            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 289        })
 290        .unwrap();
 291    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 292
 293    let message_3 = context
 294        .update(cx, |context, cx| {
 295            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 296        })
 297        .unwrap();
 298    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 299
 300    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 301    assert_eq!(
 302        messages(&context, cx),
 303        vec![
 304            (message_1.id, Role::User, 0..4),
 305            (message_2.id, Role::User, 4..8),
 306            (message_3.id, Role::User, 8..11)
 307        ]
 308    );
 309
 310    assert_eq!(
 311        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 312        [message_1.id, message_2.id, message_3.id]
 313    );
 314    assert_eq!(
 315        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 316        [message_1.id, message_3.id]
 317    );
 318
 319    let message_4 = context
 320        .update(cx, |context, cx| {
 321            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 322        })
 323        .unwrap();
 324    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 325    assert_eq!(
 326        messages(&context, cx),
 327        vec![
 328            (message_1.id, Role::User, 0..4),
 329            (message_2.id, Role::User, 4..8),
 330            (message_3.id, Role::User, 8..12),
 331            (message_4.id, Role::User, 12..12)
 332        ]
 333    );
 334    assert_eq!(
 335        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 336        [message_1.id, message_2.id, message_3.id, message_4.id]
 337    );
 338
 339    fn message_ids_for_offsets(
 340        context: &Model<Context>,
 341        offsets: &[usize],
 342        cx: &AppContext,
 343    ) -> Vec<MessageId> {
 344        context
 345            .read(cx)
 346            .messages_for_offsets(offsets.iter().copied(), cx)
 347            .into_iter()
 348            .map(|message| message.id)
 349            .collect()
 350    }
 351}
 352
 353#[gpui::test]
 354async fn test_slash_commands(cx: &mut TestAppContext) {
 355    let settings_store = cx.update(SettingsStore::test);
 356    cx.set_global(settings_store);
 357    cx.update(LanguageModelRegistry::test);
 358    cx.update(Project::init_settings);
 359    cx.update(assistant_panel::init);
 360    let fs = FakeFs::new(cx.background_executor.clone());
 361
 362    fs.insert_tree(
 363        "/test",
 364        json!({
 365            "src": {
 366                "lib.rs": "fn one() -> usize { 1 }",
 367                "main.rs": "
 368                    use crate::one;
 369                    fn main() { one(); }
 370                ".unindent(),
 371            }
 372        }),
 373    )
 374    .await;
 375
 376    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 377    slash_command_registry.register_command(file_command::FileSlashCommand, false);
 378
 379    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 380    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 381    let context =
 382        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 383
 384    let output_ranges = Rc::new(RefCell::new(HashSet::default()));
 385    context.update(cx, |_, cx| {
 386        cx.subscribe(&context, {
 387            let ranges = output_ranges.clone();
 388            move |_, _, event, _| match event {
 389                ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
 390                    for range in removed {
 391                        ranges.borrow_mut().remove(range);
 392                    }
 393                    for command in updated {
 394                        ranges.borrow_mut().insert(command.source_range.clone());
 395                    }
 396                }
 397                _ => {}
 398            }
 399        })
 400        .detach();
 401    });
 402
 403    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 404
 405    // Insert a slash command
 406    buffer.update(cx, |buffer, cx| {
 407        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 408    });
 409    assert_text_and_output_ranges(
 410        &buffer,
 411        &output_ranges.borrow(),
 412        "
 413        «/file src/lib.rs»
 414        "
 415        .unindent()
 416        .trim_end(),
 417        cx,
 418    );
 419
 420    // Edit the argument of the slash command.
 421    buffer.update(cx, |buffer, cx| {
 422        let edit_offset = buffer.text().find("lib.rs").unwrap();
 423        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 424    });
 425    assert_text_and_output_ranges(
 426        &buffer,
 427        &output_ranges.borrow(),
 428        "
 429        «/file src/main.rs»
 430        "
 431        .unindent()
 432        .trim_end(),
 433        cx,
 434    );
 435
 436    // Edit the name of the slash command, using one that doesn't exist.
 437    buffer.update(cx, |buffer, cx| {
 438        let edit_offset = buffer.text().find("/file").unwrap();
 439        buffer.edit(
 440            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 441            None,
 442            cx,
 443        );
 444    });
 445    assert_text_and_output_ranges(
 446        &buffer,
 447        &output_ranges.borrow(),
 448        "
 449        /unknown src/main.rs
 450        "
 451        .unindent()
 452        .trim_end(),
 453        cx,
 454    );
 455
 456    #[track_caller]
 457    fn assert_text_and_output_ranges(
 458        buffer: &Model<Buffer>,
 459        ranges: &HashSet<Range<language::Anchor>>,
 460        expected_marked_text: &str,
 461        cx: &mut TestAppContext,
 462    ) {
 463        let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
 464        let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
 465            let mut ranges = ranges
 466                .iter()
 467                .map(|range| range.to_offset(buffer))
 468                .collect::<Vec<_>>();
 469            ranges.sort_by_key(|a| a.start);
 470            (buffer.text(), ranges)
 471        });
 472
 473        assert_eq!(actual_text, expected_text);
 474        assert_eq!(actual_ranges, expected_ranges);
 475    }
 476}
 477
 478#[gpui::test]
 479async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 480    cx.update(prompt_library::init);
 481    let mut settings_store = cx.update(SettingsStore::test);
 482    cx.update(|cx| {
 483        settings_store
 484            .set_user_settings(
 485                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 486                cx,
 487            )
 488            .unwrap()
 489    });
 490    cx.set_global(settings_store);
 491    cx.update(language::init);
 492    cx.update(Project::init_settings);
 493    let fs = FakeFs::new(cx.executor());
 494    let project = Project::test(fs, [Path::new("/root")], cx).await;
 495    cx.update(LanguageModelRegistry::test);
 496
 497    cx.update(assistant_panel::init);
 498    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 499
 500    // Create a new context
 501    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 502    let context = cx.new_model(|cx| {
 503        Context::local(
 504            registry.clone(),
 505            Some(project),
 506            None,
 507            prompt_builder.clone(),
 508            cx,
 509        )
 510    });
 511
 512    // Insert an assistant message to simulate a response.
 513    let assistant_message_id = context.update(cx, |context, cx| {
 514        let user_message_id = context.messages(cx).next().unwrap().id;
 515        context
 516            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 517            .unwrap()
 518            .id
 519    });
 520
 521    // No edit tags
 522    edit(
 523        &context,
 524        "
 525
 526        «one
 527        two
 528        »",
 529        cx,
 530    );
 531    expect_patches(
 532        &context,
 533        "
 534
 535        one
 536        two
 537        ",
 538        &[],
 539        cx,
 540    );
 541
 542    // Partial edit step tag is added
 543    edit(
 544        &context,
 545        "
 546
 547        one
 548        two
 549        «
 550        <patch»",
 551        cx,
 552    );
 553    expect_patches(
 554        &context,
 555        "
 556
 557        one
 558        two
 559
 560        <patch",
 561        &[],
 562        cx,
 563    );
 564
 565    // The rest of the step tag is added. The unclosed
 566    // step is treated as incomplete.
 567    edit(
 568        &context,
 569        "
 570
 571        one
 572        two
 573
 574        <patch«>
 575        <edit>»",
 576        cx,
 577    );
 578    expect_patches(
 579        &context,
 580        "
 581
 582        one
 583        two
 584
 585        «<patch>
 586        <edit>»",
 587        &[&[]],
 588        cx,
 589    );
 590
 591    // The full patch is added
 592    edit(
 593        &context,
 594        "
 595
 596        one
 597        two
 598
 599        <patch>
 600        <edit>«
 601        <description>add a `two` function</description>
 602        <path>src/lib.rs</path>
 603        <operation>insert_after</operation>
 604        <old_text>fn one</old_text>
 605        <new_text>
 606        fn two() {}
 607        </new_text>
 608        </edit>
 609        </patch>
 610
 611        also,»",
 612        cx,
 613    );
 614    expect_patches(
 615        &context,
 616        "
 617
 618        one
 619        two
 620
 621        «<patch>
 622        <edit>
 623        <description>add a `two` function</description>
 624        <path>src/lib.rs</path>
 625        <operation>insert_after</operation>
 626        <old_text>fn one</old_text>
 627        <new_text>
 628        fn two() {}
 629        </new_text>
 630        </edit>
 631        </patch>
 632        »
 633        also,",
 634        &[&[AssistantEdit {
 635            path: "src/lib.rs".into(),
 636            kind: AssistantEditKind::InsertAfter {
 637                old_text: "fn one".into(),
 638                new_text: "fn two() {}".into(),
 639                description: "add a `two` function".into(),
 640            },
 641        }]],
 642        cx,
 643    );
 644
 645    // The step is manually edited.
 646    edit(
 647        &context,
 648        "
 649
 650        one
 651        two
 652
 653        <patch>
 654        <edit>
 655        <description>add a `two` function</description>
 656        <path>src/lib.rs</path>
 657        <operation>insert_after</operation>
 658        <old_text>«fn zero»</old_text>
 659        <new_text>
 660        fn two() {}
 661        </new_text>
 662        </edit>
 663        </patch>
 664
 665        also,",
 666        cx,
 667    );
 668    expect_patches(
 669        &context,
 670        "
 671
 672        one
 673        two
 674
 675        «<patch>
 676        <edit>
 677        <description>add a `two` function</description>
 678        <path>src/lib.rs</path>
 679        <operation>insert_after</operation>
 680        <old_text>fn zero</old_text>
 681        <new_text>
 682        fn two() {}
 683        </new_text>
 684        </edit>
 685        </patch>
 686        »
 687        also,",
 688        &[&[AssistantEdit {
 689            path: "src/lib.rs".into(),
 690            kind: AssistantEditKind::InsertAfter {
 691                old_text: "fn zero".into(),
 692                new_text: "fn two() {}".into(),
 693                description: "add a `two` function".into(),
 694            },
 695        }]],
 696        cx,
 697    );
 698
 699    // When setting the message role to User, the steps are cleared.
 700    context.update(cx, |context, cx| {
 701        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 702        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 703    });
 704    expect_patches(
 705        &context,
 706        "
 707
 708        one
 709        two
 710
 711        <patch>
 712        <edit>
 713        <description>add a `two` function</description>
 714        <path>src/lib.rs</path>
 715        <operation>insert_after</operation>
 716        <old_text>fn zero</old_text>
 717        <new_text>
 718        fn two() {}
 719        </new_text>
 720        </edit>
 721        </patch>
 722
 723        also,",
 724        &[],
 725        cx,
 726    );
 727
 728    // When setting the message role back to Assistant, the steps are reparsed.
 729    context.update(cx, |context, cx| {
 730        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 731    });
 732    expect_patches(
 733        &context,
 734        "
 735
 736        one
 737        two
 738
 739        «<patch>
 740        <edit>
 741        <description>add a `two` function</description>
 742        <path>src/lib.rs</path>
 743        <operation>insert_after</operation>
 744        <old_text>fn zero</old_text>
 745        <new_text>
 746        fn two() {}
 747        </new_text>
 748        </edit>
 749        </patch>
 750        »
 751        also,",
 752        &[&[AssistantEdit {
 753            path: "src/lib.rs".into(),
 754            kind: AssistantEditKind::InsertAfter {
 755                old_text: "fn zero".into(),
 756                new_text: "fn two() {}".into(),
 757                description: "add a `two` function".into(),
 758            },
 759        }]],
 760        cx,
 761    );
 762
 763    // Ensure steps are re-parsed when deserializing.
 764    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 765    let deserialized_context = cx.new_model(|cx| {
 766        Context::deserialize(
 767            serialized_context,
 768            Default::default(),
 769            registry.clone(),
 770            prompt_builder.clone(),
 771            None,
 772            None,
 773            cx,
 774        )
 775    });
 776    expect_patches(
 777        &deserialized_context,
 778        "
 779
 780        one
 781        two
 782
 783        «<patch>
 784        <edit>
 785        <description>add a `two` function</description>
 786        <path>src/lib.rs</path>
 787        <operation>insert_after</operation>
 788        <old_text>fn zero</old_text>
 789        <new_text>
 790        fn two() {}
 791        </new_text>
 792        </edit>
 793        </patch>
 794        »
 795        also,",
 796        &[&[AssistantEdit {
 797            path: "src/lib.rs".into(),
 798            kind: AssistantEditKind::InsertAfter {
 799                old_text: "fn zero".into(),
 800                new_text: "fn two() {}".into(),
 801                description: "add a `two` function".into(),
 802            },
 803        }]],
 804        cx,
 805    );
 806
 807    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
 808        context.update(cx, |context, cx| {
 809            context.buffer.update(cx, |buffer, cx| {
 810                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
 811            });
 812        });
 813        cx.executor().run_until_parked();
 814    }
 815
 816    #[track_caller]
 817    fn expect_patches(
 818        context: &Model<Context>,
 819        expected_marked_text: &str,
 820        expected_suggestions: &[&[AssistantEdit]],
 821        cx: &mut TestAppContext,
 822    ) {
 823        let expected_marked_text = expected_marked_text.unindent();
 824        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
 825
 826        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
 827            context.buffer.read_with(cx, |buffer, _| {
 828                let ranges = context
 829                    .patches
 830                    .iter()
 831                    .map(|entry| entry.range.to_offset(buffer))
 832                    .collect::<Vec<_>>();
 833                (
 834                    buffer.text(),
 835                    ranges,
 836                    context
 837                        .patches
 838                        .iter()
 839                        .map(|step| step.edits.clone())
 840                        .collect::<Vec<_>>(),
 841                )
 842            })
 843        });
 844
 845        assert_eq!(buffer_text, expected_text);
 846
 847        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
 848        assert_eq!(actual_marked_text, expected_marked_text);
 849
 850        assert_eq!(
 851            patches
 852                .iter()
 853                .map(|patch| {
 854                    patch
 855                        .iter()
 856                        .map(|edit| {
 857                            let edit = edit.as_ref().unwrap();
 858                            AssistantEdit {
 859                                path: edit.path.clone(),
 860                                kind: edit.kind.clone(),
 861                            }
 862                        })
 863                        .collect::<Vec<_>>()
 864                })
 865                .collect::<Vec<_>>(),
 866            expected_suggestions
 867        );
 868    }
 869}
 870
 871#[gpui::test]
 872async fn test_serialization(cx: &mut TestAppContext) {
 873    let settings_store = cx.update(SettingsStore::test);
 874    cx.set_global(settings_store);
 875    cx.update(LanguageModelRegistry::test);
 876    cx.update(assistant_panel::init);
 877    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 878    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 879    let context =
 880        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 881    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 882    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
 883    let message_1 = context.update(cx, |context, cx| {
 884        context
 885            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 886            .unwrap()
 887    });
 888    let message_2 = context.update(cx, |context, cx| {
 889        context
 890            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 891            .unwrap()
 892    });
 893    buffer.update(cx, |buffer, cx| {
 894        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 895        buffer.finalize_last_transaction();
 896    });
 897    let _message_3 = context.update(cx, |context, cx| {
 898        context
 899            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 900            .unwrap()
 901    });
 902    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 903    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 904    assert_eq!(
 905        cx.read(|cx| messages(&context, cx)),
 906        [
 907            (message_0, Role::User, 0..2),
 908            (message_1.id, Role::Assistant, 2..6),
 909            (message_2.id, Role::System, 6..6),
 910        ]
 911    );
 912
 913    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 914    let deserialized_context = cx.new_model(|cx| {
 915        Context::deserialize(
 916            serialized_context,
 917            Default::default(),
 918            registry.clone(),
 919            prompt_builder.clone(),
 920            None,
 921            None,
 922            cx,
 923        )
 924    });
 925    let deserialized_buffer =
 926        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
 927    assert_eq!(
 928        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 929        "a\nb\nc\n"
 930    );
 931    assert_eq!(
 932        cx.read(|cx| messages(&deserialized_context, cx)),
 933        [
 934            (message_0, Role::User, 0..2),
 935            (message_1.id, Role::Assistant, 2..6),
 936            (message_2.id, Role::System, 6..6),
 937        ]
 938    );
 939}
 940
// Randomized convergence test for collaboratively edited contexts.
//
// Creates between `MIN_PEERS` and `MAX_PEERS` replicas (defaults 2 and 5) of one context, all
// sharing `context_id` and connected through a simulated `Network`. Every replica broadcasts each
// `ContextEvent::Operation` it emits. The main loop applies `OPERATIONS` random mutations
// (default 50) to randomly chosen replicas, interleaved with random network activity: delivering
// queued operations, disconnecting guests, and reconnecting them.
//
// The loop ends only once all mutations are spent, the network is idle, and no peer is still
// disconnected. At that point every guest must match replica 0 in four respects:
// - no pending ops;
// - identical buffer text;
// - identical message anchors and message metadata;
// - identical slash command output sections.
//
// Every random decision goes through `rng`. Reordering or adding `rng` calls therefore changes
// which scenario a given seed exercises.
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    // The test shape can be tuned from the environment. Unset variables fall back to the
    // defaults; set-but-unparsable values panic.
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    // Global state needed by contexts: settings, a test model registry, and assistant panel init.
    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    // Register a few fake commands so the slash-command mutation below has names to pick from.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    // One replica per peer. The replica id equals the peer index, so replica 0 is `contexts[0]`.
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });

        // Forward every operation this replica produces to the simulated network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Keep running past the last mutation until every queued operation has been delivered and
    // every disconnected peer has rejoined; otherwise the final comparison would be premature.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // Each mutation arm is guarded by `mutation_count > 0`. Once mutations run out, every
        // roll falls through to the network arm (`_`).
        match rng.gen_range(0..100) {
            // 30%: random buffer edit.
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            // 15%: split a message at a random byte range.
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            // 15%: insert a message with a random role after a random existing message.
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            // 15%: type a random registered slash command, then insert fake output for it.
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Inserts "\n/<command>\n" at a random offset. `command_range` covers only
                    // the "/<command>" part, excluding the surrounding newlines.
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    let output_len = rng.gen_range(1..=10);
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(output_len)
                        .collect::<String>();

                    // NOTE(review): `output_len` counts chars, but section ranges are presumably
                    // byte offsets into `output_text`. A multi-byte char from `RandomCharIter`
                    // could place a range boundary mid-character. Confirm this is intended.
                    let num_sections = rng.gen_range(0..=3);
                    let mut sections = Vec::with_capacity(num_sections);
                    for _ in 0..num_sections {
                        let section_start = rng.gen_range(0..output_len);
                        let section_end = rng.gen_range(section_start..=output_len);
                        sections.push(SlashCommandOutputSection {
                            range: section_start..section_end,
                            icon: ui::IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        });
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?}",
                        context_index,
                        command_range,
                        sections
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        Task::ready(Ok(SlashCommandOutput {
                            text: output_text,
                            sections,
                            run_commands_in_text: false,
                        })),
                        true,
                        false,
                        cx,
                    );
                });
                // Presumably lets the (already-ready) command-output task be processed before
                // the next step. Confirm against `insert_command_output`.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            // 10%: set a random message's status to Done, Pending, or Error.
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            // Remaining rolls (and every roll once mutations are exhausted): network activity
            // for the chosen replica.
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    // Rejoin, then catch up in both directions against replica 0:
                    // - send the ops replica 0's `version` lacks;
                    // - apply the ops this replica's `version` lacks.
                    network.lock().reconnect_peer(replica_id, 0);

                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Occasionally drop a guest off the network. Replica 0 is never disconnected.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    // Otherwise deliver whatever operations are queued for this replica.
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Convergence check: every guest must agree with replica 0.
    // NOTE(review): only the guests' `pending_ops` are asserted empty; replica 0's are not.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty());
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1208
1209#[gpui::test]
1210fn test_mark_cache_anchors(cx: &mut AppContext) {
1211    let settings_store = SettingsStore::test(cx);
1212    LanguageModelRegistry::test(cx);
1213    cx.set_global(settings_store);
1214    assistant_panel::init(cx);
1215    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1216    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1217    let context =
1218        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1219    let buffer = context.read(cx).buffer.clone();
1220
1221    // Create a test cache configuration
1222    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1223        max_cache_anchors: 3,
1224        should_speculate: true,
1225        min_total_token: 10,
1226    });
1227
1228    let message_1 = context.read(cx).message_anchors[0].clone();
1229
1230    context.update(cx, |context, cx| {
1231        context.mark_cache_anchors(cache_configuration, false, cx)
1232    });
1233
1234    assert_eq!(
1235        messages_cache(&context, cx)
1236            .iter()
1237            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1238            .count(),
1239        0,
1240        "Empty messages should not have any cache anchors."
1241    );
1242
1243    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1244    let message_2 = context
1245        .update(cx, |context, cx| {
1246            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1247        })
1248        .unwrap();
1249
1250    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1251    let message_3 = context
1252        .update(cx, |context, cx| {
1253            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1254        })
1255        .unwrap();
1256    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1257
1258    context.update(cx, |context, cx| {
1259        context.mark_cache_anchors(cache_configuration, false, cx)
1260    });
1261    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1262    assert_eq!(
1263        messages_cache(&context, cx)
1264            .iter()
1265            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1266            .count(),
1267        0,
1268        "Messages should not be marked for cache before going over the token minimum."
1269    );
1270    context.update(cx, |context, _| {
1271        context.token_count = Some(20);
1272    });
1273
1274    context.update(cx, |context, cx| {
1275        context.mark_cache_anchors(cache_configuration, true, cx)
1276    });
1277    assert_eq!(
1278        messages_cache(&context, cx)
1279            .iter()
1280            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1281            .collect::<Vec<bool>>(),
1282        vec![true, true, false],
1283        "Last message should not be an anchor on speculative request."
1284    );
1285
1286    context
1287        .update(cx, |context, cx| {
1288            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1289        })
1290        .unwrap();
1291
1292    context.update(cx, |context, cx| {
1293        context.mark_cache_anchors(cache_configuration, false, cx)
1294    });
1295    assert_eq!(
1296        messages_cache(&context, cx)
1297            .iter()
1298            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1299            .collect::<Vec<bool>>(),
1300        vec![false, true, true, false],
1301        "Most recent message should also be cached if not a speculative request."
1302    );
1303    context.update(cx, |context, cx| {
1304        context.update_cache_status_for_completion(cx)
1305    });
1306    assert_eq!(
1307        messages_cache(&context, cx)
1308            .iter()
1309            .map(|(_, cache)| cache
1310                .as_ref()
1311                .map_or(None, |cache| Some(cache.status.clone())))
1312            .collect::<Vec<Option<CacheStatus>>>(),
1313        vec![
1314            Some(CacheStatus::Cached),
1315            Some(CacheStatus::Cached),
1316            Some(CacheStatus::Cached),
1317            None
1318        ],
1319        "All user messages prior to anchor should be marked as cached."
1320    );
1321
1322    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1323    context.update(cx, |context, cx| {
1324        context.mark_cache_anchors(cache_configuration, false, cx)
1325    });
1326    assert_eq!(
1327        messages_cache(&context, cx)
1328            .iter()
1329            .map(|(_, cache)| cache
1330                .as_ref()
1331                .map_or(None, |cache| Some(cache.status.clone())))
1332            .collect::<Vec<Option<CacheStatus>>>(),
1333        vec![
1334            Some(CacheStatus::Cached),
1335            Some(CacheStatus::Cached),
1336            Some(CacheStatus::Pending),
1337            None
1338        ],
1339        "Modifying a message should invalidate it's cache but leave previous messages."
1340    );
1341    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1342    context.update(cx, |context, cx| {
1343        context.mark_cache_anchors(cache_configuration, false, cx)
1344    });
1345    assert_eq!(
1346        messages_cache(&context, cx)
1347            .iter()
1348            .map(|(_, cache)| cache
1349                .as_ref()
1350                .map_or(None, |cache| Some(cache.status.clone())))
1351            .collect::<Vec<Option<CacheStatus>>>(),
1352        vec![
1353            Some(CacheStatus::Pending),
1354            Some(CacheStatus::Pending),
1355            Some(CacheStatus::Pending),
1356            None
1357        ],
1358        "Modifying a message should invalidate all future messages."
1359    );
1360}
1361
1362fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1363    context
1364        .read(cx)
1365        .messages(cx)
1366        .map(|message| (message.id, message.role, message.offset_range))
1367        .collect()
1368}
1369
1370fn messages_cache(
1371    context: &Model<Context>,
1372    cx: &AppContext,
1373) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1374    context
1375        .read(cx)
1376        .messages(cx)
1377        .map(|message| (message.id, message.cache.clone()))
1378        .collect()
1379}
1380
/// Minimal `SlashCommand` used by these tests. The wrapped string is the command's name.
#[derive(Clone)]
struct FakeSlashCommand(String);
1383
1384impl SlashCommand for FakeSlashCommand {
1385    fn name(&self) -> String {
1386        self.0.clone()
1387    }
1388
1389    fn description(&self) -> String {
1390        format!("Fake slash command: {}", self.0)
1391    }
1392
1393    fn menu_text(&self) -> String {
1394        format!("Run fake command: {}", self.0)
1395    }
1396
1397    fn complete_argument(
1398        self: Arc<Self>,
1399        _arguments: &[String],
1400        _cancel: Arc<AtomicBool>,
1401        _workspace: Option<WeakView<Workspace>>,
1402        _cx: &mut WindowContext,
1403    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1404        Task::ready(Ok(vec![]))
1405    }
1406
1407    fn requires_argument(&self) -> bool {
1408        false
1409    }
1410
1411    fn run(
1412        self: Arc<Self>,
1413        _arguments: &[String],
1414        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1415        _context_buffer: BufferSnapshot,
1416        _workspace: WeakView<Workspace>,
1417        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1418        _cx: &mut WindowContext,
1419    ) -> Task<Result<SlashCommandOutput>> {
1420        Task::ready(Ok(SlashCommandOutput {
1421            text: format!("Executed fake command: {}", self.0),
1422            sections: vec![],
1423            run_commands_in_text: false,
1424        }))
1425    }
1426}