context_tests.rs

   1use super::{AssistantEdit, MessageCacheMetadata};
   2use crate::{
   3    assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
   4    Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
   5};
   6use anyhow::Result;
   7use assistant_slash_command::{
   8    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
   9    SlashCommandRegistry, SlashCommandResult,
  10};
  11use collections::HashSet;
  12use fs::FakeFs;
  13use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  14use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  15use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  16use parking_lot::Mutex;
  17use pretty_assertions::assert_eq;
  18use project::Project;
  19use rand::prelude::*;
  20use serde_json::json;
  21use settings::SettingsStore;
  22use std::{
  23    cell::RefCell,
  24    env,
  25    ops::Range,
  26    path::Path,
  27    rc::Rc,
  28    sync::{atomic::AtomicBool, Arc},
  29};
  30use text::{network::Network, OffsetRangeExt as _, ReplicaId};
  31use ui::{Context as _, WindowContext};
  32use unindent::Unindent;
  33use util::{
  34    test::{generate_marked_text, marked_text_ranges},
  35    RandomCharIter,
  36};
  37use workspace::Workspace;
  38
  39#[gpui::test]
  40fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  41    let settings_store = SettingsStore::test(cx);
  42    LanguageModelRegistry::test(cx);
  43    cx.set_global(settings_store);
  44    assistant_panel::init(cx);
  45    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  46    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  47    let context =
  48        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
  49    let buffer = context.read(cx).buffer.clone();
  50
  51    let message_1 = context.read(cx).message_anchors[0].clone();
  52    assert_eq!(
  53        messages(&context, cx),
  54        vec![(message_1.id, Role::User, 0..0)]
  55    );
  56
  57    let message_2 = context.update(cx, |context, cx| {
  58        context
  59            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  60            .unwrap()
  61    });
  62    assert_eq!(
  63        messages(&context, cx),
  64        vec![
  65            (message_1.id, Role::User, 0..1),
  66            (message_2.id, Role::Assistant, 1..1)
  67        ]
  68    );
  69
  70    buffer.update(cx, |buffer, cx| {
  71        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  72    });
  73    assert_eq!(
  74        messages(&context, cx),
  75        vec![
  76            (message_1.id, Role::User, 0..2),
  77            (message_2.id, Role::Assistant, 2..3)
  78        ]
  79    );
  80
  81    let message_3 = context.update(cx, |context, cx| {
  82        context
  83            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  84            .unwrap()
  85    });
  86    assert_eq!(
  87        messages(&context, cx),
  88        vec![
  89            (message_1.id, Role::User, 0..2),
  90            (message_2.id, Role::Assistant, 2..4),
  91            (message_3.id, Role::User, 4..4)
  92        ]
  93    );
  94
  95    let message_4 = context.update(cx, |context, cx| {
  96        context
  97            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  98            .unwrap()
  99    });
 100    assert_eq!(
 101        messages(&context, cx),
 102        vec![
 103            (message_1.id, Role::User, 0..2),
 104            (message_2.id, Role::Assistant, 2..4),
 105            (message_4.id, Role::User, 4..5),
 106            (message_3.id, Role::User, 5..5),
 107        ]
 108    );
 109
 110    buffer.update(cx, |buffer, cx| {
 111        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 112    });
 113    assert_eq!(
 114        messages(&context, cx),
 115        vec![
 116            (message_1.id, Role::User, 0..2),
 117            (message_2.id, Role::Assistant, 2..4),
 118            (message_4.id, Role::User, 4..6),
 119            (message_3.id, Role::User, 6..7),
 120        ]
 121    );
 122
 123    // Deleting across message boundaries merges the messages.
 124    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 125    assert_eq!(
 126        messages(&context, cx),
 127        vec![
 128            (message_1.id, Role::User, 0..3),
 129            (message_3.id, Role::User, 3..4),
 130        ]
 131    );
 132
 133    // Undoing the deletion should also undo the merge.
 134    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 135    assert_eq!(
 136        messages(&context, cx),
 137        vec![
 138            (message_1.id, Role::User, 0..2),
 139            (message_2.id, Role::Assistant, 2..4),
 140            (message_4.id, Role::User, 4..6),
 141            (message_3.id, Role::User, 6..7),
 142        ]
 143    );
 144
 145    // Redoing the deletion should also redo the merge.
 146    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 147    assert_eq!(
 148        messages(&context, cx),
 149        vec![
 150            (message_1.id, Role::User, 0..3),
 151            (message_3.id, Role::User, 3..4),
 152        ]
 153    );
 154
 155    // Ensure we can still insert after a merged message.
 156    let message_5 = context.update(cx, |context, cx| {
 157        context
 158            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 159            .unwrap()
 160    });
 161    assert_eq!(
 162        messages(&context, cx),
 163        vec![
 164            (message_1.id, Role::User, 0..3),
 165            (message_5.id, Role::System, 3..4),
 166            (message_3.id, Role::User, 4..5)
 167        ]
 168    );
 169}
 170
 171#[gpui::test]
 172fn test_message_splitting(cx: &mut AppContext) {
 173    let settings_store = SettingsStore::test(cx);
 174    cx.set_global(settings_store);
 175    LanguageModelRegistry::test(cx);
 176    assistant_panel::init(cx);
 177    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 178
 179    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 180    let context =
 181        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 182    let buffer = context.read(cx).buffer.clone();
 183
 184    let message_1 = context.read(cx).message_anchors[0].clone();
 185    assert_eq!(
 186        messages(&context, cx),
 187        vec![(message_1.id, Role::User, 0..0)]
 188    );
 189
 190    buffer.update(cx, |buffer, cx| {
 191        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 192    });
 193
 194    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 195    let message_2 = message_2.unwrap();
 196
 197    // We recycle newlines in the middle of a split message
 198    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 199    assert_eq!(
 200        messages(&context, cx),
 201        vec![
 202            (message_1.id, Role::User, 0..4),
 203            (message_2.id, Role::User, 4..16),
 204        ]
 205    );
 206
 207    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 208    let message_3 = message_3.unwrap();
 209
 210    // We don't recycle newlines at the end of a split message
 211    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 212    assert_eq!(
 213        messages(&context, cx),
 214        vec![
 215            (message_1.id, Role::User, 0..4),
 216            (message_3.id, Role::User, 4..5),
 217            (message_2.id, Role::User, 5..17),
 218        ]
 219    );
 220
 221    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 222    let message_4 = message_4.unwrap();
 223    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 224    assert_eq!(
 225        messages(&context, cx),
 226        vec![
 227            (message_1.id, Role::User, 0..4),
 228            (message_3.id, Role::User, 4..5),
 229            (message_2.id, Role::User, 5..9),
 230            (message_4.id, Role::User, 9..17),
 231        ]
 232    );
 233
 234    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 235    let message_5 = message_5.unwrap();
 236    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 237    assert_eq!(
 238        messages(&context, cx),
 239        vec![
 240            (message_1.id, Role::User, 0..4),
 241            (message_3.id, Role::User, 4..5),
 242            (message_2.id, Role::User, 5..9),
 243            (message_4.id, Role::User, 9..10),
 244            (message_5.id, Role::User, 10..18),
 245        ]
 246    );
 247
 248    let (message_6, message_7) =
 249        context.update(cx, |context, cx| context.split_message(14..16, cx));
 250    let message_6 = message_6.unwrap();
 251    let message_7 = message_7.unwrap();
 252    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 253    assert_eq!(
 254        messages(&context, cx),
 255        vec![
 256            (message_1.id, Role::User, 0..4),
 257            (message_3.id, Role::User, 4..5),
 258            (message_2.id, Role::User, 5..9),
 259            (message_4.id, Role::User, 9..10),
 260            (message_5.id, Role::User, 10..14),
 261            (message_6.id, Role::User, 14..17),
 262            (message_7.id, Role::User, 17..19),
 263        ]
 264    );
 265}
 266
 267#[gpui::test]
 268fn test_messages_for_offsets(cx: &mut AppContext) {
 269    let settings_store = SettingsStore::test(cx);
 270    LanguageModelRegistry::test(cx);
 271    cx.set_global(settings_store);
 272    assistant_panel::init(cx);
 273    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 274    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 275    let context =
 276        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 277    let buffer = context.read(cx).buffer.clone();
 278
 279    let message_1 = context.read(cx).message_anchors[0].clone();
 280    assert_eq!(
 281        messages(&context, cx),
 282        vec![(message_1.id, Role::User, 0..0)]
 283    );
 284
 285    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 286    let message_2 = context
 287        .update(cx, |context, cx| {
 288            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 289        })
 290        .unwrap();
 291    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 292
 293    let message_3 = context
 294        .update(cx, |context, cx| {
 295            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 296        })
 297        .unwrap();
 298    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 299
 300    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 301    assert_eq!(
 302        messages(&context, cx),
 303        vec![
 304            (message_1.id, Role::User, 0..4),
 305            (message_2.id, Role::User, 4..8),
 306            (message_3.id, Role::User, 8..11)
 307        ]
 308    );
 309
 310    assert_eq!(
 311        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 312        [message_1.id, message_2.id, message_3.id]
 313    );
 314    assert_eq!(
 315        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 316        [message_1.id, message_3.id]
 317    );
 318
 319    let message_4 = context
 320        .update(cx, |context, cx| {
 321            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 322        })
 323        .unwrap();
 324    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 325    assert_eq!(
 326        messages(&context, cx),
 327        vec![
 328            (message_1.id, Role::User, 0..4),
 329            (message_2.id, Role::User, 4..8),
 330            (message_3.id, Role::User, 8..12),
 331            (message_4.id, Role::User, 12..12)
 332        ]
 333    );
 334    assert_eq!(
 335        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 336        [message_1.id, message_2.id, message_3.id, message_4.id]
 337    );
 338
 339    fn message_ids_for_offsets(
 340        context: &Model<Context>,
 341        offsets: &[usize],
 342        cx: &AppContext,
 343    ) -> Vec<MessageId> {
 344        context
 345            .read(cx)
 346            .messages_for_offsets(offsets.iter().copied(), cx)
 347            .into_iter()
 348            .map(|message| message.id)
 349            .collect()
 350    }
 351}
 352
 353#[gpui::test]
 354async fn test_slash_commands(cx: &mut TestAppContext) {
 355    let settings_store = cx.update(SettingsStore::test);
 356    cx.set_global(settings_store);
 357    cx.update(LanguageModelRegistry::test);
 358    cx.update(Project::init_settings);
 359    cx.update(assistant_panel::init);
 360    let fs = FakeFs::new(cx.background_executor.clone());
 361
 362    fs.insert_tree(
 363        "/test",
 364        json!({
 365            "src": {
 366                "lib.rs": "fn one() -> usize { 1 }",
 367                "main.rs": "
 368                    use crate::one;
 369                    fn main() { one(); }
 370                ".unindent(),
 371            }
 372        }),
 373    )
 374    .await;
 375
 376    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 377    slash_command_registry.register_command(file_command::FileSlashCommand, false);
 378
 379    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 380    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 381    let context =
 382        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 383
 384    let output_ranges = Rc::new(RefCell::new(HashSet::default()));
 385    context.update(cx, |_, cx| {
 386        cx.subscribe(&context, {
 387            let ranges = output_ranges.clone();
 388            move |_, _, event, _| match event {
 389                ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
 390                    for range in removed {
 391                        ranges.borrow_mut().remove(range);
 392                    }
 393                    for command in updated {
 394                        ranges.borrow_mut().insert(command.source_range.clone());
 395                    }
 396                }
 397                _ => {}
 398            }
 399        })
 400        .detach();
 401    });
 402
 403    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 404
 405    // Insert a slash command
 406    buffer.update(cx, |buffer, cx| {
 407        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 408    });
 409    assert_text_and_output_ranges(
 410        &buffer,
 411        &output_ranges.borrow(),
 412        "
 413        «/file src/lib.rs»
 414        "
 415        .unindent()
 416        .trim_end(),
 417        cx,
 418    );
 419
 420    // Edit the argument of the slash command.
 421    buffer.update(cx, |buffer, cx| {
 422        let edit_offset = buffer.text().find("lib.rs").unwrap();
 423        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 424    });
 425    assert_text_and_output_ranges(
 426        &buffer,
 427        &output_ranges.borrow(),
 428        "
 429        «/file src/main.rs»
 430        "
 431        .unindent()
 432        .trim_end(),
 433        cx,
 434    );
 435
 436    // Edit the name of the slash command, using one that doesn't exist.
 437    buffer.update(cx, |buffer, cx| {
 438        let edit_offset = buffer.text().find("/file").unwrap();
 439        buffer.edit(
 440            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 441            None,
 442            cx,
 443        );
 444    });
 445    assert_text_and_output_ranges(
 446        &buffer,
 447        &output_ranges.borrow(),
 448        "
 449        /unknown src/main.rs
 450        "
 451        .unindent()
 452        .trim_end(),
 453        cx,
 454    );
 455
 456    #[track_caller]
 457    fn assert_text_and_output_ranges(
 458        buffer: &Model<Buffer>,
 459        ranges: &HashSet<Range<language::Anchor>>,
 460        expected_marked_text: &str,
 461        cx: &mut TestAppContext,
 462    ) {
 463        let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
 464        let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
 465            let mut ranges = ranges
 466                .iter()
 467                .map(|range| range.to_offset(buffer))
 468                .collect::<Vec<_>>();
 469            ranges.sort_by_key(|a| a.start);
 470            (buffer.text(), ranges)
 471        });
 472
 473        assert_eq!(actual_text, expected_text);
 474        assert_eq!(actual_ranges, expected_ranges);
 475    }
 476}
 477
 478#[gpui::test]
 479async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 480    cx.update(prompt_library::init);
 481    let mut settings_store = cx.update(SettingsStore::test);
 482    cx.update(|cx| {
 483        settings_store
 484            .set_user_settings(
 485                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 486                cx,
 487            )
 488            .unwrap()
 489    });
 490    cx.set_global(settings_store);
 491    cx.update(language::init);
 492    cx.update(Project::init_settings);
 493    let fs = FakeFs::new(cx.executor());
 494    let project = Project::test(fs, [Path::new("/root")], cx).await;
 495    cx.update(LanguageModelRegistry::test);
 496
 497    cx.update(assistant_panel::init);
 498    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 499
 500    // Create a new context
 501    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 502    let context = cx.new_model(|cx| {
 503        Context::local(
 504            registry.clone(),
 505            Some(project),
 506            None,
 507            prompt_builder.clone(),
 508            cx,
 509        )
 510    });
 511
 512    // Insert an assistant message to simulate a response.
 513    let assistant_message_id = context.update(cx, |context, cx| {
 514        let user_message_id = context.messages(cx).next().unwrap().id;
 515        context
 516            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 517            .unwrap()
 518            .id
 519    });
 520
 521    // No edit tags
 522    edit(
 523        &context,
 524        "
 525
 526        «one
 527        two
 528        »",
 529        cx,
 530    );
 531    expect_patches(
 532        &context,
 533        "
 534
 535        one
 536        two
 537        ",
 538        &[],
 539        cx,
 540    );
 541
 542    // Partial edit step tag is added
 543    edit(
 544        &context,
 545        "
 546
 547        one
 548        two
 549        «
 550        <patch»",
 551        cx,
 552    );
 553    expect_patches(
 554        &context,
 555        "
 556
 557        one
 558        two
 559
 560        <patch",
 561        &[],
 562        cx,
 563    );
 564
 565    // The rest of the step tag is added. The unclosed
 566    // step is treated as incomplete.
 567    edit(
 568        &context,
 569        "
 570
 571        one
 572        two
 573
 574        <patch«>
 575        <edit>»",
 576        cx,
 577    );
 578    expect_patches(
 579        &context,
 580        "
 581
 582        one
 583        two
 584
 585        «<patch>
 586        <edit>»",
 587        &[&[]],
 588        cx,
 589    );
 590
 591    // The full patch is added
 592    edit(
 593        &context,
 594        "
 595
 596        one
 597        two
 598
 599        <patch>
 600        <edit>«
 601        <description>add a `two` function</description>
 602        <path>src/lib.rs</path>
 603        <operation>insert_after</operation>
 604        <old_text>fn one</old_text>
 605        <new_text>
 606        fn two() {}
 607        </new_text>
 608        </edit>
 609        </patch>
 610
 611        also,»",
 612        cx,
 613    );
 614    expect_patches(
 615        &context,
 616        "
 617
 618        one
 619        two
 620
 621        «<patch>
 622        <edit>
 623        <description>add a `two` function</description>
 624        <path>src/lib.rs</path>
 625        <operation>insert_after</operation>
 626        <old_text>fn one</old_text>
 627        <new_text>
 628        fn two() {}
 629        </new_text>
 630        </edit>
 631        </patch>
 632        »
 633        also,",
 634        &[&[AssistantEdit {
 635            path: "src/lib.rs".into(),
 636            kind: AssistantEditKind::InsertAfter {
 637                old_text: "fn one".into(),
 638                new_text: "fn two() {}".into(),
 639                description: Some("add a `two` function".into()),
 640            },
 641        }]],
 642        cx,
 643    );
 644
 645    // The step is manually edited.
 646    edit(
 647        &context,
 648        "
 649
 650        one
 651        two
 652
 653        <patch>
 654        <edit>
 655        <description>add a `two` function</description>
 656        <path>src/lib.rs</path>
 657        <operation>insert_after</operation>
 658        <old_text>«fn zero»</old_text>
 659        <new_text>
 660        fn two() {}
 661        </new_text>
 662        </edit>
 663        </patch>
 664
 665        also,",
 666        cx,
 667    );
 668    expect_patches(
 669        &context,
 670        "
 671
 672        one
 673        two
 674
 675        «<patch>
 676        <edit>
 677        <description>add a `two` function</description>
 678        <path>src/lib.rs</path>
 679        <operation>insert_after</operation>
 680        <old_text>fn zero</old_text>
 681        <new_text>
 682        fn two() {}
 683        </new_text>
 684        </edit>
 685        </patch>
 686        »
 687        also,",
 688        &[&[AssistantEdit {
 689            path: "src/lib.rs".into(),
 690            kind: AssistantEditKind::InsertAfter {
 691                old_text: "fn zero".into(),
 692                new_text: "fn two() {}".into(),
 693                description: Some("add a `two` function".into()),
 694            },
 695        }]],
 696        cx,
 697    );
 698
 699    // When setting the message role to User, the steps are cleared.
 700    context.update(cx, |context, cx| {
 701        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 702        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 703    });
 704    expect_patches(
 705        &context,
 706        "
 707
 708        one
 709        two
 710
 711        <patch>
 712        <edit>
 713        <description>add a `two` function</description>
 714        <path>src/lib.rs</path>
 715        <operation>insert_after</operation>
 716        <old_text>fn zero</old_text>
 717        <new_text>
 718        fn two() {}
 719        </new_text>
 720        </edit>
 721        </patch>
 722
 723        also,",
 724        &[],
 725        cx,
 726    );
 727
 728    // When setting the message role back to Assistant, the steps are reparsed.
 729    context.update(cx, |context, cx| {
 730        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 731    });
 732    expect_patches(
 733        &context,
 734        "
 735
 736        one
 737        two
 738
 739        «<patch>
 740        <edit>
 741        <description>add a `two` function</description>
 742        <path>src/lib.rs</path>
 743        <operation>insert_after</operation>
 744        <old_text>fn zero</old_text>
 745        <new_text>
 746        fn two() {}
 747        </new_text>
 748        </edit>
 749        </patch>
 750        »
 751        also,",
 752        &[&[AssistantEdit {
 753            path: "src/lib.rs".into(),
 754            kind: AssistantEditKind::InsertAfter {
 755                old_text: "fn zero".into(),
 756                new_text: "fn two() {}".into(),
 757                description: Some("add a `two` function".into()),
 758            },
 759        }]],
 760        cx,
 761    );
 762
 763    // Ensure steps are re-parsed when deserializing.
 764    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 765    let deserialized_context = cx.new_model(|cx| {
 766        Context::deserialize(
 767            serialized_context,
 768            Default::default(),
 769            registry.clone(),
 770            prompt_builder.clone(),
 771            None,
 772            None,
 773            cx,
 774        )
 775    });
 776    expect_patches(
 777        &deserialized_context,
 778        "
 779
 780        one
 781        two
 782
 783        «<patch>
 784        <edit>
 785        <description>add a `two` function</description>
 786        <path>src/lib.rs</path>
 787        <operation>insert_after</operation>
 788        <old_text>fn zero</old_text>
 789        <new_text>
 790        fn two() {}
 791        </new_text>
 792        </edit>
 793        </patch>
 794        »
 795        also,",
 796        &[&[AssistantEdit {
 797            path: "src/lib.rs".into(),
 798            kind: AssistantEditKind::InsertAfter {
 799                old_text: "fn zero".into(),
 800                new_text: "fn two() {}".into(),
 801                description: Some("add a `two` function".into()),
 802            },
 803        }]],
 804        cx,
 805    );
 806
 807    fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
 808        context.update(cx, |context, cx| {
 809            context.buffer.update(cx, |buffer, cx| {
 810                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
 811            });
 812        });
 813        cx.executor().run_until_parked();
 814    }
 815
 816    #[track_caller]
 817    fn expect_patches(
 818        context: &Model<Context>,
 819        expected_marked_text: &str,
 820        expected_suggestions: &[&[AssistantEdit]],
 821        cx: &mut TestAppContext,
 822    ) {
 823        let expected_marked_text = expected_marked_text.unindent();
 824        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
 825
 826        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
 827            context.buffer.read_with(cx, |buffer, _| {
 828                let ranges = context
 829                    .patches
 830                    .iter()
 831                    .map(|entry| entry.range.to_offset(buffer))
 832                    .collect::<Vec<_>>();
 833                (
 834                    buffer.text(),
 835                    ranges,
 836                    context
 837                        .patches
 838                        .iter()
 839                        .map(|step| step.edits.clone())
 840                        .collect::<Vec<_>>(),
 841                )
 842            })
 843        });
 844
 845        assert_eq!(buffer_text, expected_text);
 846
 847        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
 848        assert_eq!(actual_marked_text, expected_marked_text);
 849
 850        assert_eq!(
 851            patches
 852                .iter()
 853                .map(|patch| {
 854                    patch
 855                        .iter()
 856                        .map(|edit| {
 857                            let edit = edit.as_ref().unwrap();
 858                            AssistantEdit {
 859                                path: edit.path.clone(),
 860                                kind: edit.kind.clone(),
 861                            }
 862                        })
 863                        .collect::<Vec<_>>()
 864                })
 865                .collect::<Vec<_>>(),
 866            expected_suggestions
 867        );
 868    }
 869}
 870
 871#[gpui::test]
 872async fn test_serialization(cx: &mut TestAppContext) {
 873    let settings_store = cx.update(SettingsStore::test);
 874    cx.set_global(settings_store);
 875    cx.update(LanguageModelRegistry::test);
 876    cx.update(assistant_panel::init);
 877    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 878    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 879    let context =
 880        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 881    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 882    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
 883    let message_1 = context.update(cx, |context, cx| {
 884        context
 885            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 886            .unwrap()
 887    });
 888    let message_2 = context.update(cx, |context, cx| {
 889        context
 890            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 891            .unwrap()
 892    });
 893    buffer.update(cx, |buffer, cx| {
 894        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 895        buffer.finalize_last_transaction();
 896    });
 897    let _message_3 = context.update(cx, |context, cx| {
 898        context
 899            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 900            .unwrap()
 901    });
 902    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 903    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 904    assert_eq!(
 905        cx.read(|cx| messages(&context, cx)),
 906        [
 907            (message_0, Role::User, 0..2),
 908            (message_1.id, Role::Assistant, 2..6),
 909            (message_2.id, Role::System, 6..6),
 910        ]
 911    );
 912
 913    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 914    let deserialized_context = cx.new_model(|cx| {
 915        Context::deserialize(
 916            serialized_context,
 917            Default::default(),
 918            registry.clone(),
 919            prompt_builder.clone(),
 920            None,
 921            None,
 922            cx,
 923        )
 924    });
 925    let deserialized_buffer =
 926        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
 927    assert_eq!(
 928        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 929        "a\nb\nc\n"
 930    );
 931    assert_eq!(
 932        cx.read(|cx| messages(&deserialized_context, cx)),
 933        [
 934            (message_0, Role::User, 0..2),
 935            (message_1.id, Role::Assistant, 2..6),
 936            (message_2.id, Role::System, 6..6),
 937        ]
 938    );
 939}
 940
// Randomized convergence test for collaborative editing of a single context.
//
// Between `MIN_PEERS` and `MAX_PEERS` replicas (env vars, defaults 2 and 5) sharing one
// `ContextId` are created, replica `i` using replica id `i`. They perform `OPERATIONS` (env var,
// default 50) random mutations while exchanging operations through a simulated `Network` that
// also disconnects and reconnects peers. Once the mutation budget is spent, the network is idle
// and no peer is disconnected, every replica must have no pending ops and must match replica 0
// on buffer text, message anchors, message metadata and slash command output sections.
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    // Test parameters, overridable through environment variables.
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    let settings_store = cx.update(SettingsStore::test);
    cx.set_global(settings_store);
    cx.update(LanguageModelRegistry::test);

    cx.update(assistant_panel::init);
    // Fake commands whose names are picked at random when inserting slash command output below.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    // All replicas share `context_id`; replica `i` is created with replica id `i`.
    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        let context = cx.new_model(|cx| {
            Context::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });

        // Broadcast every operation this replica emits on the simulated network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    // Remaining budget of random mutations.
    let mut mutation_count = operations;

    // Run until the mutation budget is spent, nothing is in flight, and every peer that was
    // disconnected has reconnected.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        // Roll 0..100 to choose an action. Mutation arms are guarded by `mutation_count > 0`,
        // so once the budget is exhausted every roll falls through to the networking arm.
        match rng.gen_range(0..100) {
            // 30%: random buffer edit.
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            // 15%: split a message at a random byte range.
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            // 15%: insert a message with a random role after a randomly chosen message.
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            // 15%: insert a `/command` line at a random offset and attach random output to it.
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Insert "\n/<name>\n" and remember the offsets of "/<name>" itself.
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    // 1..=10 random chars of output, carriage returns excluded.
                    let output_len = rng.gen_range(1..=10);
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(output_len)
                        .collect::<String>();

                    // Up to 3 random (possibly overlapping, unordered) sections.
                    // NOTE(review): bounds come from the char count `output_len`, not
                    // `output_text.len()`; if `RandomCharIter` yields multi-byte chars these
                    // may not be byte/char-boundary aligned — presumably tolerated downstream.
                    let num_sections = rng.gen_range(0..=3);
                    let mut sections = Vec::with_capacity(num_sections);
                    for _ in 0..num_sections {
                        let section_start = rng.gen_range(0..output_len);
                        let section_end = rng.gen_range(section_start..=output_len);
                        sections.push(SlashCommandOutputSection {
                            range: section_start..section_end,
                            icon: ui::IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        });
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?}",
                        context_index,
                        command_range,
                        sections
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        Task::ready(Ok(SlashCommandOutput {
                            text: output_text,
                            sections,
                            run_commands_in_text: false,
                        }
                        .to_event_stream())),
                        true,
                        false,
                        cx,
                    );
                });
                // Let any work started by `insert_command_output` run before continuing.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            // 10%: set a random message's status to Done, Pending or Error.
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            // Remaining rolls (and every roll once the budget is spent): networking.
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    // Reconnect and exchange what each side may have missed: the guest's ops
                    // relative to replica 0's version, and replica 0's ops relative to the
                    // guest's version.
                    network.lock().reconnect_peer(replica_id, 0);

                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                // `gen_bool` runs before the replica check, so it consumes randomness even for
                // replica 0; replica 0 itself is never disconnected.
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    // Deliver everything queued for this replica.
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Every replica must have converged to replica 0's state.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty());
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1209
1210#[gpui::test]
1211fn test_mark_cache_anchors(cx: &mut AppContext) {
1212    let settings_store = SettingsStore::test(cx);
1213    LanguageModelRegistry::test(cx);
1214    cx.set_global(settings_store);
1215    assistant_panel::init(cx);
1216    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1217    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1218    let context =
1219        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1220    let buffer = context.read(cx).buffer.clone();
1221
1222    // Create a test cache configuration
1223    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1224        max_cache_anchors: 3,
1225        should_speculate: true,
1226        min_total_token: 10,
1227    });
1228
1229    let message_1 = context.read(cx).message_anchors[0].clone();
1230
1231    context.update(cx, |context, cx| {
1232        context.mark_cache_anchors(cache_configuration, false, cx)
1233    });
1234
1235    assert_eq!(
1236        messages_cache(&context, cx)
1237            .iter()
1238            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1239            .count(),
1240        0,
1241        "Empty messages should not have any cache anchors."
1242    );
1243
1244    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1245    let message_2 = context
1246        .update(cx, |context, cx| {
1247            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1248        })
1249        .unwrap();
1250
1251    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1252    let message_3 = context
1253        .update(cx, |context, cx| {
1254            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1255        })
1256        .unwrap();
1257    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1258
1259    context.update(cx, |context, cx| {
1260        context.mark_cache_anchors(cache_configuration, false, cx)
1261    });
1262    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1263    assert_eq!(
1264        messages_cache(&context, cx)
1265            .iter()
1266            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1267            .count(),
1268        0,
1269        "Messages should not be marked for cache before going over the token minimum."
1270    );
1271    context.update(cx, |context, _| {
1272        context.token_count = Some(20);
1273    });
1274
1275    context.update(cx, |context, cx| {
1276        context.mark_cache_anchors(cache_configuration, true, cx)
1277    });
1278    assert_eq!(
1279        messages_cache(&context, cx)
1280            .iter()
1281            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1282            .collect::<Vec<bool>>(),
1283        vec![true, true, false],
1284        "Last message should not be an anchor on speculative request."
1285    );
1286
1287    context
1288        .update(cx, |context, cx| {
1289            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1290        })
1291        .unwrap();
1292
1293    context.update(cx, |context, cx| {
1294        context.mark_cache_anchors(cache_configuration, false, cx)
1295    });
1296    assert_eq!(
1297        messages_cache(&context, cx)
1298            .iter()
1299            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1300            .collect::<Vec<bool>>(),
1301        vec![false, true, true, false],
1302        "Most recent message should also be cached if not a speculative request."
1303    );
1304    context.update(cx, |context, cx| {
1305        context.update_cache_status_for_completion(cx)
1306    });
1307    assert_eq!(
1308        messages_cache(&context, cx)
1309            .iter()
1310            .map(|(_, cache)| cache
1311                .as_ref()
1312                .map_or(None, |cache| Some(cache.status.clone())))
1313            .collect::<Vec<Option<CacheStatus>>>(),
1314        vec![
1315            Some(CacheStatus::Cached),
1316            Some(CacheStatus::Cached),
1317            Some(CacheStatus::Cached),
1318            None
1319        ],
1320        "All user messages prior to anchor should be marked as cached."
1321    );
1322
1323    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1324    context.update(cx, |context, cx| {
1325        context.mark_cache_anchors(cache_configuration, false, cx)
1326    });
1327    assert_eq!(
1328        messages_cache(&context, cx)
1329            .iter()
1330            .map(|(_, cache)| cache
1331                .as_ref()
1332                .map_or(None, |cache| Some(cache.status.clone())))
1333            .collect::<Vec<Option<CacheStatus>>>(),
1334        vec![
1335            Some(CacheStatus::Cached),
1336            Some(CacheStatus::Cached),
1337            Some(CacheStatus::Pending),
1338            None
1339        ],
1340        "Modifying a message should invalidate it's cache but leave previous messages."
1341    );
1342    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1343    context.update(cx, |context, cx| {
1344        context.mark_cache_anchors(cache_configuration, false, cx)
1345    });
1346    assert_eq!(
1347        messages_cache(&context, cx)
1348            .iter()
1349            .map(|(_, cache)| cache
1350                .as_ref()
1351                .map_or(None, |cache| Some(cache.status.clone())))
1352            .collect::<Vec<Option<CacheStatus>>>(),
1353        vec![
1354            Some(CacheStatus::Pending),
1355            Some(CacheStatus::Pending),
1356            Some(CacheStatus::Pending),
1357            None
1358        ],
1359        "Modifying a message should invalidate all future messages."
1360    );
1361}
1362
1363fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1364    context
1365        .read(cx)
1366        .messages(cx)
1367        .map(|message| (message.id, message.role, message.offset_range))
1368        .collect()
1369}
1370
1371fn messages_cache(
1372    context: &Model<Context>,
1373    cx: &AppContext,
1374) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1375    context
1376        .read(cx)
1377        .messages(cx)
1378        .map(|message| (message.id, message.cache.clone()))
1379        .collect()
1380}
1381
1382#[derive(Clone)]
1383struct FakeSlashCommand(String);
1384
1385impl SlashCommand for FakeSlashCommand {
1386    fn name(&self) -> String {
1387        self.0.clone()
1388    }
1389
1390    fn description(&self) -> String {
1391        format!("Fake slash command: {}", self.0)
1392    }
1393
1394    fn menu_text(&self) -> String {
1395        format!("Run fake command: {}", self.0)
1396    }
1397
1398    fn complete_argument(
1399        self: Arc<Self>,
1400        _arguments: &[String],
1401        _cancel: Arc<AtomicBool>,
1402        _workspace: Option<WeakView<Workspace>>,
1403        _cx: &mut WindowContext,
1404    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1405        Task::ready(Ok(vec![]))
1406    }
1407
1408    fn requires_argument(&self) -> bool {
1409        false
1410    }
1411
1412    fn run(
1413        self: Arc<Self>,
1414        _arguments: &[String],
1415        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1416        _context_buffer: BufferSnapshot,
1417        _workspace: WeakView<Workspace>,
1418        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1419        _cx: &mut WindowContext,
1420    ) -> Task<SlashCommandResult> {
1421        Task::ready(Ok(SlashCommandOutput {
1422            text: format!("Executed fake command: {}", self.0),
1423            sections: vec![],
1424            run_commands_in_text: false,
1425        }
1426        .to_event_stream()))
1427    }
1428}