context_tests.rs

   1use crate::{
   2    assistant_panel, prompt_library, slash_command::file_command, workflow::tool, CacheStatus,
   3    Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
   8    SlashCommandRegistry,
   9};
  10use collections::HashSet;
  11use fs::{FakeFs, Fs as _};
  12use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
  13use indoc::indoc;
  14use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
  15use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
  16use parking_lot::Mutex;
  17use project::Project;
  18use rand::prelude::*;
  19use rope::Point;
  20use serde_json::json;
  21use settings::SettingsStore;
  22use std::{
  23    cell::RefCell,
  24    env,
  25    ops::Range,
  26    path::Path,
  27    rc::Rc,
  28    sync::{atomic::AtomicBool, Arc},
  29};
  30use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToPoint as _};
  31use ui::{Context as _, WindowContext};
  32use unindent::Unindent;
  33use util::{test::marked_text_ranges, RandomCharIter};
  34use workspace::Workspace;
  35
  36use super::MessageCacheMetadata;
  37
  38#[gpui::test]
  39fn test_inserting_and_removing_messages(cx: &mut AppContext) {
  40    let settings_store = SettingsStore::test(cx);
  41    LanguageModelRegistry::test(cx);
  42    cx.set_global(settings_store);
  43    assistant_panel::init(cx);
  44    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  45    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  46    let context =
  47        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
  48    let buffer = context.read(cx).buffer.clone();
  49
  50    let message_1 = context.read(cx).message_anchors[0].clone();
  51    assert_eq!(
  52        messages(&context, cx),
  53        vec![(message_1.id, Role::User, 0..0)]
  54    );
  55
  56    let message_2 = context.update(cx, |context, cx| {
  57        context
  58            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  59            .unwrap()
  60    });
  61    assert_eq!(
  62        messages(&context, cx),
  63        vec![
  64            (message_1.id, Role::User, 0..1),
  65            (message_2.id, Role::Assistant, 1..1)
  66        ]
  67    );
  68
  69    buffer.update(cx, |buffer, cx| {
  70        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  71    });
  72    assert_eq!(
  73        messages(&context, cx),
  74        vec![
  75            (message_1.id, Role::User, 0..2),
  76            (message_2.id, Role::Assistant, 2..3)
  77        ]
  78    );
  79
  80    let message_3 = context.update(cx, |context, cx| {
  81        context
  82            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  83            .unwrap()
  84    });
  85    assert_eq!(
  86        messages(&context, cx),
  87        vec![
  88            (message_1.id, Role::User, 0..2),
  89            (message_2.id, Role::Assistant, 2..4),
  90            (message_3.id, Role::User, 4..4)
  91        ]
  92    );
  93
  94    let message_4 = context.update(cx, |context, cx| {
  95        context
  96            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  97            .unwrap()
  98    });
  99    assert_eq!(
 100        messages(&context, cx),
 101        vec![
 102            (message_1.id, Role::User, 0..2),
 103            (message_2.id, Role::Assistant, 2..4),
 104            (message_4.id, Role::User, 4..5),
 105            (message_3.id, Role::User, 5..5),
 106        ]
 107    );
 108
 109    buffer.update(cx, |buffer, cx| {
 110        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 111    });
 112    assert_eq!(
 113        messages(&context, cx),
 114        vec![
 115            (message_1.id, Role::User, 0..2),
 116            (message_2.id, Role::Assistant, 2..4),
 117            (message_4.id, Role::User, 4..6),
 118            (message_3.id, Role::User, 6..7),
 119        ]
 120    );
 121
 122    // Deleting across message boundaries merges the messages.
 123    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 124    assert_eq!(
 125        messages(&context, cx),
 126        vec![
 127            (message_1.id, Role::User, 0..3),
 128            (message_3.id, Role::User, 3..4),
 129        ]
 130    );
 131
 132    // Undoing the deletion should also undo the merge.
 133    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 134    assert_eq!(
 135        messages(&context, cx),
 136        vec![
 137            (message_1.id, Role::User, 0..2),
 138            (message_2.id, Role::Assistant, 2..4),
 139            (message_4.id, Role::User, 4..6),
 140            (message_3.id, Role::User, 6..7),
 141        ]
 142    );
 143
 144    // Redoing the deletion should also redo the merge.
 145    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 146    assert_eq!(
 147        messages(&context, cx),
 148        vec![
 149            (message_1.id, Role::User, 0..3),
 150            (message_3.id, Role::User, 3..4),
 151        ]
 152    );
 153
 154    // Ensure we can still insert after a merged message.
 155    let message_5 = context.update(cx, |context, cx| {
 156        context
 157            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 158            .unwrap()
 159    });
 160    assert_eq!(
 161        messages(&context, cx),
 162        vec![
 163            (message_1.id, Role::User, 0..3),
 164            (message_5.id, Role::System, 3..4),
 165            (message_3.id, Role::User, 4..5)
 166        ]
 167    );
 168}
 169
 170#[gpui::test]
 171fn test_message_splitting(cx: &mut AppContext) {
 172    let settings_store = SettingsStore::test(cx);
 173    cx.set_global(settings_store);
 174    LanguageModelRegistry::test(cx);
 175    assistant_panel::init(cx);
 176    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 177
 178    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 179    let context =
 180        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 181    let buffer = context.read(cx).buffer.clone();
 182
 183    let message_1 = context.read(cx).message_anchors[0].clone();
 184    assert_eq!(
 185        messages(&context, cx),
 186        vec![(message_1.id, Role::User, 0..0)]
 187    );
 188
 189    buffer.update(cx, |buffer, cx| {
 190        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 191    });
 192
 193    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 194    let message_2 = message_2.unwrap();
 195
 196    // We recycle newlines in the middle of a split message
 197    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 198    assert_eq!(
 199        messages(&context, cx),
 200        vec![
 201            (message_1.id, Role::User, 0..4),
 202            (message_2.id, Role::User, 4..16),
 203        ]
 204    );
 205
 206    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 207    let message_3 = message_3.unwrap();
 208
 209    // We don't recycle newlines at the end of a split message
 210    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 211    assert_eq!(
 212        messages(&context, cx),
 213        vec![
 214            (message_1.id, Role::User, 0..4),
 215            (message_3.id, Role::User, 4..5),
 216            (message_2.id, Role::User, 5..17),
 217        ]
 218    );
 219
 220    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 221    let message_4 = message_4.unwrap();
 222    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 223    assert_eq!(
 224        messages(&context, cx),
 225        vec![
 226            (message_1.id, Role::User, 0..4),
 227            (message_3.id, Role::User, 4..5),
 228            (message_2.id, Role::User, 5..9),
 229            (message_4.id, Role::User, 9..17),
 230        ]
 231    );
 232
 233    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 234    let message_5 = message_5.unwrap();
 235    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 236    assert_eq!(
 237        messages(&context, cx),
 238        vec![
 239            (message_1.id, Role::User, 0..4),
 240            (message_3.id, Role::User, 4..5),
 241            (message_2.id, Role::User, 5..9),
 242            (message_4.id, Role::User, 9..10),
 243            (message_5.id, Role::User, 10..18),
 244        ]
 245    );
 246
 247    let (message_6, message_7) =
 248        context.update(cx, |context, cx| context.split_message(14..16, cx));
 249    let message_6 = message_6.unwrap();
 250    let message_7 = message_7.unwrap();
 251    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 252    assert_eq!(
 253        messages(&context, cx),
 254        vec![
 255            (message_1.id, Role::User, 0..4),
 256            (message_3.id, Role::User, 4..5),
 257            (message_2.id, Role::User, 5..9),
 258            (message_4.id, Role::User, 9..10),
 259            (message_5.id, Role::User, 10..14),
 260            (message_6.id, Role::User, 14..17),
 261            (message_7.id, Role::User, 17..19),
 262        ]
 263    );
 264}
 265
 266#[gpui::test]
 267fn test_messages_for_offsets(cx: &mut AppContext) {
 268    let settings_store = SettingsStore::test(cx);
 269    LanguageModelRegistry::test(cx);
 270    cx.set_global(settings_store);
 271    assistant_panel::init(cx);
 272    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 273    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 274    let context =
 275        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
 276    let buffer = context.read(cx).buffer.clone();
 277
 278    let message_1 = context.read(cx).message_anchors[0].clone();
 279    assert_eq!(
 280        messages(&context, cx),
 281        vec![(message_1.id, Role::User, 0..0)]
 282    );
 283
 284    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 285    let message_2 = context
 286        .update(cx, |context, cx| {
 287            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 288        })
 289        .unwrap();
 290    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 291
 292    let message_3 = context
 293        .update(cx, |context, cx| {
 294            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 295        })
 296        .unwrap();
 297    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 298
 299    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 300    assert_eq!(
 301        messages(&context, cx),
 302        vec![
 303            (message_1.id, Role::User, 0..4),
 304            (message_2.id, Role::User, 4..8),
 305            (message_3.id, Role::User, 8..11)
 306        ]
 307    );
 308
 309    assert_eq!(
 310        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 311        [message_1.id, message_2.id, message_3.id]
 312    );
 313    assert_eq!(
 314        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 315        [message_1.id, message_3.id]
 316    );
 317
 318    let message_4 = context
 319        .update(cx, |context, cx| {
 320            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 321        })
 322        .unwrap();
 323    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 324    assert_eq!(
 325        messages(&context, cx),
 326        vec![
 327            (message_1.id, Role::User, 0..4),
 328            (message_2.id, Role::User, 4..8),
 329            (message_3.id, Role::User, 8..12),
 330            (message_4.id, Role::User, 12..12)
 331        ]
 332    );
 333    assert_eq!(
 334        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 335        [message_1.id, message_2.id, message_3.id, message_4.id]
 336    );
 337
 338    fn message_ids_for_offsets(
 339        context: &Model<Context>,
 340        offsets: &[usize],
 341        cx: &AppContext,
 342    ) -> Vec<MessageId> {
 343        context
 344            .read(cx)
 345            .messages_for_offsets(offsets.iter().copied(), cx)
 346            .into_iter()
 347            .map(|message| message.id)
 348            .collect()
 349    }
 350}
 351
 352#[gpui::test]
 353async fn test_slash_commands(cx: &mut TestAppContext) {
 354    let settings_store = cx.update(SettingsStore::test);
 355    cx.set_global(settings_store);
 356    cx.update(LanguageModelRegistry::test);
 357    cx.update(Project::init_settings);
 358    cx.update(assistant_panel::init);
 359    let fs = FakeFs::new(cx.background_executor.clone());
 360
 361    fs.insert_tree(
 362        "/test",
 363        json!({
 364            "src": {
 365                "lib.rs": "fn one() -> usize { 1 }",
 366                "main.rs": "
 367                    use crate::one;
 368                    fn main() { one(); }
 369                ".unindent(),
 370            }
 371        }),
 372    )
 373    .await;
 374
 375    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 376    slash_command_registry.register_command(file_command::FileSlashCommand, false);
 377
 378    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 379    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 380    let context =
 381        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 382
 383    let output_ranges = Rc::new(RefCell::new(HashSet::default()));
 384    context.update(cx, |_, cx| {
 385        cx.subscribe(&context, {
 386            let ranges = output_ranges.clone();
 387            move |_, _, event, _| match event {
 388                ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
 389                    for range in removed {
 390                        ranges.borrow_mut().remove(range);
 391                    }
 392                    for command in updated {
 393                        ranges.borrow_mut().insert(command.source_range.clone());
 394                    }
 395                }
 396                _ => {}
 397            }
 398        })
 399        .detach();
 400    });
 401
 402    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 403
 404    // Insert a slash command
 405    buffer.update(cx, |buffer, cx| {
 406        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 407    });
 408    assert_text_and_output_ranges(
 409        &buffer,
 410        &output_ranges.borrow(),
 411        "
 412        «/file src/lib.rs»
 413        "
 414        .unindent()
 415        .trim_end(),
 416        cx,
 417    );
 418
 419    // Edit the argument of the slash command.
 420    buffer.update(cx, |buffer, cx| {
 421        let edit_offset = buffer.text().find("lib.rs").unwrap();
 422        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 423    });
 424    assert_text_and_output_ranges(
 425        &buffer,
 426        &output_ranges.borrow(),
 427        "
 428        «/file src/main.rs»
 429        "
 430        .unindent()
 431        .trim_end(),
 432        cx,
 433    );
 434
 435    // Edit the name of the slash command, using one that doesn't exist.
 436    buffer.update(cx, |buffer, cx| {
 437        let edit_offset = buffer.text().find("/file").unwrap();
 438        buffer.edit(
 439            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 440            None,
 441            cx,
 442        );
 443    });
 444    assert_text_and_output_ranges(
 445        &buffer,
 446        &output_ranges.borrow(),
 447        "
 448        /unknown src/main.rs
 449        "
 450        .unindent()
 451        .trim_end(),
 452        cx,
 453    );
 454
 455    #[track_caller]
 456    fn assert_text_and_output_ranges(
 457        buffer: &Model<Buffer>,
 458        ranges: &HashSet<Range<language::Anchor>>,
 459        expected_marked_text: &str,
 460        cx: &mut TestAppContext,
 461    ) {
 462        let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
 463        let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
 464            let mut ranges = ranges
 465                .iter()
 466                .map(|range| range.to_offset(buffer))
 467                .collect::<Vec<_>>();
 468            ranges.sort_by_key(|a| a.start);
 469            (buffer.text(), ranges)
 470        });
 471
 472        assert_eq!(actual_text, expected_text);
 473        assert_eq!(actual_ranges, expected_ranges);
 474    }
 475}
 476
 477#[gpui::test]
 478async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 479    cx.update(prompt_library::init);
 480    let settings_store = cx.update(SettingsStore::test);
 481    cx.set_global(settings_store);
 482    cx.update(Project::init_settings);
 483    let fs = FakeFs::new(cx.executor());
 484    fs.as_fake()
 485        .insert_tree(
 486            "/root",
 487            json!({
 488                "hello.rs": r#"
 489                    fn hello() {
 490                        println!("Hello, World!");
 491                    }
 492                "#.unindent()
 493            }),
 494        )
 495        .await;
 496    let project = Project::test(fs, [Path::new("/root")], cx).await;
 497    cx.update(LanguageModelRegistry::test);
 498
 499    let model = cx.read(|cx| {
 500        LanguageModelRegistry::read_global(cx)
 501            .active_model()
 502            .unwrap()
 503    });
 504    cx.update(assistant_panel::init);
 505    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 506
 507    // Create a new context
 508    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 509    let context = cx.new_model(|cx| {
 510        Context::local(
 511            registry.clone(),
 512            Some(project),
 513            None,
 514            prompt_builder.clone(),
 515            cx,
 516        )
 517    });
 518    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 519
 520    // Simulate user input
 521    let user_message = indoc! {r#"
 522        Please add unnecessary complexity to this code:
 523
 524        ```hello.rs
 525        fn main() {
 526            println!("Hello, World!");
 527        }
 528        ```
 529    "#};
 530    buffer.update(cx, |buffer, cx| {
 531        buffer.edit([(0..0, user_message)], None, cx);
 532    });
 533
 534    // Simulate LLM response with edit steps
 535    let llm_response = indoc! {r#"
 536        Sure, I can help you with that. Here's a step-by-step process:
 537
 538        <step>
 539        First, let's extract the greeting into a separate function:
 540
 541        ```rust
 542        fn greet() {
 543            println!("Hello, World!");
 544        }
 545
 546        fn main() {
 547            greet();
 548        }
 549        ```
 550        </step>
 551
 552        <step>
 553        Now, let's make the greeting customizable:
 554
 555        ```rust
 556        fn greet(name: &str) {
 557            println!("Hello, {}!", name);
 558        }
 559
 560        fn main() {
 561            greet("World");
 562        }
 563        ```
 564        </step>
 565
 566        These changes make the code more modular and flexible.
 567    "#};
 568
 569    // Simulate the assist method to trigger the LLM response
 570    context.update(cx, |context, cx| context.assist(cx));
 571    cx.run_until_parked();
 572
 573    // Retrieve the assistant response message's start from the context
 574    let response_start_row = context.read_with(cx, |context, cx| {
 575        let buffer = context.buffer.read(cx);
 576        context.message_anchors[1].start.to_point(buffer).row
 577    });
 578
 579    // Simulate the LLM completion
 580    model
 581        .as_fake()
 582        .stream_last_completion_response(llm_response.to_string());
 583    model.as_fake().end_last_completion_stream();
 584
 585    // Wait for the completion to be processed
 586    cx.run_until_parked();
 587
 588    // Verify that the edit steps were parsed correctly
 589    context.read_with(cx, |context, cx| {
 590        assert_eq!(
 591            workflow_steps(context, cx),
 592            vec![
 593                (
 594                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
 595                    WorkflowStepTestStatus::Pending
 596                ),
 597                (
 598                    Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
 599                    WorkflowStepTestStatus::Pending
 600                ),
 601            ]
 602        );
 603    });
 604
 605    model
 606        .as_fake()
 607        .respond_to_last_tool_use(tool::WorkflowStepResolutionTool {
 608            step_title: "Title".into(),
 609            suggestions: vec![tool::WorkflowSuggestionTool {
 610                path: "/root/hello.rs".into(),
 611                // Simulate a symbol name that's slightly different than our outline query
 612                kind: tool::WorkflowSuggestionToolKind::Update {
 613                    symbol: "fn main()".into(),
 614                    description: "Extract a greeting function".into(),
 615                },
 616            }],
 617        });
 618
 619    // Wait for tool use to be processed.
 620    cx.run_until_parked();
 621
 622    // Verify that the first edit step is not pending anymore.
 623    context.read_with(cx, |context, cx| {
 624        assert_eq!(
 625            workflow_steps(context, cx),
 626            vec![
 627                (
 628                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
 629                    WorkflowStepTestStatus::Resolved
 630                ),
 631                (
 632                    Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
 633                    WorkflowStepTestStatus::Pending
 634                ),
 635            ]
 636        );
 637    });
 638
 639    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 640    enum WorkflowStepTestStatus {
 641        Pending,
 642        Resolved,
 643        Error,
 644    }
 645
 646    fn workflow_steps(
 647        context: &Context,
 648        cx: &AppContext,
 649    ) -> Vec<(Range<Point>, WorkflowStepTestStatus)> {
 650        context
 651            .workflow_steps
 652            .iter()
 653            .map(|step| {
 654                let buffer = context.buffer.read(cx);
 655                let status = match &step.step.read(cx).resolution {
 656                    None => WorkflowStepTestStatus::Pending,
 657                    Some(Ok(_)) => WorkflowStepTestStatus::Resolved,
 658                    Some(Err(_)) => WorkflowStepTestStatus::Error,
 659                };
 660                (step.range.to_point(buffer), status)
 661            })
 662            .collect()
 663    }
 664}
 665
 666#[gpui::test]
 667async fn test_serialization(cx: &mut TestAppContext) {
 668    let settings_store = cx.update(SettingsStore::test);
 669    cx.set_global(settings_store);
 670    cx.update(LanguageModelRegistry::test);
 671    cx.update(assistant_panel::init);
 672    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 673    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 674    let context =
 675        cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
 676    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 677    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
 678    let message_1 = context.update(cx, |context, cx| {
 679        context
 680            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 681            .unwrap()
 682    });
 683    let message_2 = context.update(cx, |context, cx| {
 684        context
 685            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 686            .unwrap()
 687    });
 688    buffer.update(cx, |buffer, cx| {
 689        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 690        buffer.finalize_last_transaction();
 691    });
 692    let _message_3 = context.update(cx, |context, cx| {
 693        context
 694            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 695            .unwrap()
 696    });
 697    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 698    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 699    assert_eq!(
 700        cx.read(|cx| messages(&context, cx)),
 701        [
 702            (message_0, Role::User, 0..2),
 703            (message_1.id, Role::Assistant, 2..6),
 704            (message_2.id, Role::System, 6..6),
 705        ]
 706    );
 707
 708    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 709    let deserialized_context = cx.new_model(|cx| {
 710        Context::deserialize(
 711            serialized_context,
 712            Default::default(),
 713            registry.clone(),
 714            prompt_builder.clone(),
 715            None,
 716            None,
 717            cx,
 718        )
 719    });
 720    let deserialized_buffer =
 721        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
 722    assert_eq!(
 723        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 724        "a\nb\nc\n"
 725    );
 726    assert_eq!(
 727        cx.read(|cx| messages(&deserialized_context, cx)),
 728        [
 729            (message_0, Role::User, 0..2),
 730            (message_1.id, Role::Assistant, 2..6),
 731            (message_2.id, Role::System, 6..6),
 732        ]
 733    );
 734}
 735
 736#[gpui::test(iterations = 100)]
 737async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
 738    let min_peers = env::var("MIN_PEERS")
 739        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
 740        .unwrap_or(2);
 741    let max_peers = env::var("MAX_PEERS")
 742        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 743        .unwrap_or(5);
 744    let operations = env::var("OPERATIONS")
 745        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 746        .unwrap_or(50);
 747
 748    let settings_store = cx.update(SettingsStore::test);
 749    cx.set_global(settings_store);
 750    cx.update(LanguageModelRegistry::test);
 751
 752    cx.update(assistant_panel::init);
 753    let slash_commands = cx.update(SlashCommandRegistry::default_global);
 754    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
 755    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
 756    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
 757
 758    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
 759    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
 760    let mut contexts = Vec::new();
 761
 762    let num_peers = rng.gen_range(min_peers..=max_peers);
 763    let context_id = ContextId::new();
 764    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 765    for i in 0..num_peers {
 766        let context = cx.new_model(|cx| {
 767            Context::new(
 768                context_id.clone(),
 769                i as ReplicaId,
 770                language::Capability::ReadWrite,
 771                registry.clone(),
 772                prompt_builder.clone(),
 773                None,
 774                None,
 775                cx,
 776            )
 777        });
 778
 779        cx.update(|cx| {
 780            cx.subscribe(&context, {
 781                let network = network.clone();
 782                move |_, event, _| {
 783                    if let ContextEvent::Operation(op) = event {
 784                        network
 785                            .lock()
 786                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
 787                    }
 788                }
 789            })
 790            .detach();
 791        });
 792
 793        contexts.push(context);
 794        network.lock().add_peer(i as ReplicaId);
 795    }
 796
 797    let mut mutation_count = operations;
 798
 799    while mutation_count > 0
 800        || !network.lock().is_idle()
 801        || network.lock().contains_disconnected_peers()
 802    {
 803        let context_index = rng.gen_range(0..contexts.len());
 804        let context = &contexts[context_index];
 805
 806        match rng.gen_range(0..100) {
 807            0..=29 if mutation_count > 0 => {
 808                log::info!("Context {}: edit buffer", context_index);
 809                context.update(cx, |context, cx| {
 810                    context
 811                        .buffer
 812                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
 813                });
 814                mutation_count -= 1;
 815            }
 816            30..=44 if mutation_count > 0 => {
 817                context.update(cx, |context, cx| {
 818                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
 819                    log::info!("Context {}: split message at {:?}", context_index, range);
 820                    context.split_message(range, cx);
 821                });
 822                mutation_count -= 1;
 823            }
 824            45..=59 if mutation_count > 0 => {
 825                context.update(cx, |context, cx| {
 826                    if let Some(message) = context.messages(cx).choose(&mut rng) {
 827                        let role = *[Role::User, Role::Assistant, Role::System]
 828                            .choose(&mut rng)
 829                            .unwrap();
 830                        log::info!(
 831                            "Context {}: insert message after {:?} with {:?}",
 832                            context_index,
 833                            message.id,
 834                            role
 835                        );
 836                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
 837                    }
 838                });
 839                mutation_count -= 1;
 840            }
 841            60..=74 if mutation_count > 0 => {
 842                context.update(cx, |context, cx| {
 843                    let command_text = "/".to_string()
 844                        + slash_commands
 845                            .command_names()
 846                            .choose(&mut rng)
 847                            .unwrap()
 848                            .clone()
 849                            .as_ref();
 850
 851                    let command_range = context.buffer.update(cx, |buffer, cx| {
 852                        let offset = buffer.random_byte_range(0, &mut rng).start;
 853                        buffer.edit(
 854                            [(offset..offset, format!("\n{}\n", command_text))],
 855                            None,
 856                            cx,
 857                        );
 858                        offset + 1..offset + 1 + command_text.len()
 859                    });
 860
 861                    let output_len = rng.gen_range(1..=10);
 862                    let output_text = RandomCharIter::new(&mut rng)
 863                        .filter(|c| *c != '\r')
 864                        .take(output_len)
 865                        .collect::<String>();
 866
 867                    let num_sections = rng.gen_range(0..=3);
 868                    let mut sections = Vec::with_capacity(num_sections);
 869                    for _ in 0..num_sections {
 870                        let section_start = rng.gen_range(0..output_len);
 871                        let section_end = rng.gen_range(section_start..=output_len);
 872                        sections.push(SlashCommandOutputSection {
 873                            range: section_start..section_end,
 874                            icon: ui::IconName::Ai,
 875                            label: "section".into(),
 876                        });
 877                    }
 878
 879                    log::info!(
 880                        "Context {}: insert slash command output at {:?} with {:?}",
 881                        context_index,
 882                        command_range,
 883                        sections
 884                    );
 885
 886                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
 887                        ..context.buffer.read(cx).anchor_after(command_range.end);
 888                    context.insert_command_output(
 889                        command_range,
 890                        Task::ready(Ok(SlashCommandOutput {
 891                            text: output_text,
 892                            sections,
 893                            run_commands_in_text: false,
 894                        })),
 895                        true,
 896                        false,
 897                        cx,
 898                    );
 899                });
 900                cx.run_until_parked();
 901                mutation_count -= 1;
 902            }
 903            75..=84 if mutation_count > 0 => {
 904                context.update(cx, |context, cx| {
 905                    if let Some(message) = context.messages(cx).choose(&mut rng) {
 906                        let new_status = match rng.gen_range(0..3) {
 907                            0 => MessageStatus::Done,
 908                            1 => MessageStatus::Pending,
 909                            _ => MessageStatus::Error(SharedString::from("Random error")),
 910                        };
 911                        log::info!(
 912                            "Context {}: update message {:?} status to {:?}",
 913                            context_index,
 914                            message.id,
 915                            new_status
 916                        );
 917                        context.update_metadata(message.id, cx, |metadata| {
 918                            metadata.status = new_status;
 919                        });
 920                    }
 921                });
 922                mutation_count -= 1;
 923            }
 924            _ => {
 925                let replica_id = context_index as ReplicaId;
 926                if network.lock().is_disconnected(replica_id) {
 927                    network.lock().reconnect_peer(replica_id, 0);
 928
 929                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
 930                        let host_context = &contexts[0].read(cx);
 931                        let guest_context = context.read(cx);
 932                        (
 933                            guest_context.serialize_ops(&host_context.version(cx), cx),
 934                            host_context.serialize_ops(&guest_context.version(cx), cx),
 935                        )
 936                    });
 937                    let ops_to_send = ops_to_send.await;
 938                    let ops_to_receive = ops_to_receive
 939                        .await
 940                        .into_iter()
 941                        .map(ContextOperation::from_proto)
 942                        .collect::<Result<Vec<_>>>()
 943                        .unwrap();
 944                    log::info!(
 945                        "Context {}: reconnecting. Sent {} operations, received {} operations",
 946                        context_index,
 947                        ops_to_send.len(),
 948                        ops_to_receive.len()
 949                    );
 950
 951                    network.lock().broadcast(replica_id, ops_to_send);
 952                    context
 953                        .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
 954                        .unwrap();
 955                } else if rng.gen_bool(0.1) && replica_id != 0 {
 956                    log::info!("Context {}: disconnecting", context_index);
 957                    network.lock().disconnect_peer(replica_id);
 958                } else if network.lock().has_unreceived(replica_id) {
 959                    log::info!("Context {}: applying operations", context_index);
 960                    let ops = network.lock().receive(replica_id);
 961                    let ops = ops
 962                        .into_iter()
 963                        .map(ContextOperation::from_proto)
 964                        .collect::<Result<Vec<_>>>()
 965                        .unwrap();
 966                    context
 967                        .update(cx, |context, cx| context.apply_ops(ops, cx))
 968                        .unwrap();
 969                }
 970            }
 971        }
 972    }
 973
 974    cx.read(|cx| {
 975        let first_context = contexts[0].read(cx);
 976        for context in &contexts[1..] {
 977            let context = context.read(cx);
 978            assert!(context.pending_ops.is_empty());
 979            assert_eq!(
 980                context.buffer.read(cx).text(),
 981                first_context.buffer.read(cx).text(),
 982                "Context {} text != Context 0 text",
 983                context.buffer.read(cx).replica_id()
 984            );
 985            assert_eq!(
 986                context.message_anchors,
 987                first_context.message_anchors,
 988                "Context {} messages != Context 0 messages",
 989                context.buffer.read(cx).replica_id()
 990            );
 991            assert_eq!(
 992                context.messages_metadata,
 993                first_context.messages_metadata,
 994                "Context {} message metadata != Context 0 message metadata",
 995                context.buffer.read(cx).replica_id()
 996            );
 997            assert_eq!(
 998                context.slash_command_output_sections,
 999                first_context.slash_command_output_sections,
1000                "Context {} slash command output sections != Context 0 slash command output sections",
1001                context.buffer.read(cx).replica_id()
1002            );
1003        }
1004    });
1005}
1006
1007#[gpui::test]
1008fn test_mark_cache_anchors(cx: &mut AppContext) {
1009    let settings_store = SettingsStore::test(cx);
1010    LanguageModelRegistry::test(cx);
1011    cx.set_global(settings_store);
1012    assistant_panel::init(cx);
1013    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1014    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1015    let context =
1016        cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1017    let buffer = context.read(cx).buffer.clone();
1018
1019    // Create a test cache configuration
1020    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1021        max_cache_anchors: 3,
1022        should_speculate: true,
1023        min_total_token: 10,
1024    });
1025
1026    let message_1 = context.read(cx).message_anchors[0].clone();
1027
1028    context.update(cx, |context, cx| {
1029        context.mark_cache_anchors(cache_configuration, false, cx)
1030    });
1031
1032    assert_eq!(
1033        messages_cache(&context, cx)
1034            .iter()
1035            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1036            .count(),
1037        0,
1038        "Empty messages should not have any cache anchors."
1039    );
1040
1041    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1042    let message_2 = context
1043        .update(cx, |context, cx| {
1044            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1045        })
1046        .unwrap();
1047
1048    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1049    let message_3 = context
1050        .update(cx, |context, cx| {
1051            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1052        })
1053        .unwrap();
1054    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1055
1056    context.update(cx, |context, cx| {
1057        context.mark_cache_anchors(cache_configuration, false, cx)
1058    });
1059    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1060    assert_eq!(
1061        messages_cache(&context, cx)
1062            .iter()
1063            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1064            .count(),
1065        0,
1066        "Messages should not be marked for cache before going over the token minimum."
1067    );
1068    context.update(cx, |context, _| {
1069        context.token_count = Some(20);
1070    });
1071
1072    context.update(cx, |context, cx| {
1073        context.mark_cache_anchors(cache_configuration, true, cx)
1074    });
1075    assert_eq!(
1076        messages_cache(&context, cx)
1077            .iter()
1078            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1079            .collect::<Vec<bool>>(),
1080        vec![true, true, false],
1081        "Last message should not be an anchor on speculative request."
1082    );
1083
1084    context
1085        .update(cx, |context, cx| {
1086            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1087        })
1088        .unwrap();
1089
1090    context.update(cx, |context, cx| {
1091        context.mark_cache_anchors(cache_configuration, false, cx)
1092    });
1093    assert_eq!(
1094        messages_cache(&context, cx)
1095            .iter()
1096            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1097            .collect::<Vec<bool>>(),
1098        vec![false, true, true, false],
1099        "Most recent message should also be cached if not a speculative request."
1100    );
1101    context.update(cx, |context, cx| {
1102        context.update_cache_status_for_completion(cx)
1103    });
1104    assert_eq!(
1105        messages_cache(&context, cx)
1106            .iter()
1107            .map(|(_, cache)| cache
1108                .as_ref()
1109                .map_or(None, |cache| Some(cache.status.clone())))
1110            .collect::<Vec<Option<CacheStatus>>>(),
1111        vec![
1112            Some(CacheStatus::Cached),
1113            Some(CacheStatus::Cached),
1114            Some(CacheStatus::Cached),
1115            None
1116        ],
1117        "All user messages prior to anchor should be marked as cached."
1118    );
1119
1120    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1121    context.update(cx, |context, cx| {
1122        context.mark_cache_anchors(cache_configuration, false, cx)
1123    });
1124    assert_eq!(
1125        messages_cache(&context, cx)
1126            .iter()
1127            .map(|(_, cache)| cache
1128                .as_ref()
1129                .map_or(None, |cache| Some(cache.status.clone())))
1130            .collect::<Vec<Option<CacheStatus>>>(),
1131        vec![
1132            Some(CacheStatus::Cached),
1133            Some(CacheStatus::Cached),
1134            Some(CacheStatus::Pending),
1135            None
1136        ],
1137        "Modifying a message should invalidate it's cache but leave previous messages."
1138    );
1139    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1140    context.update(cx, |context, cx| {
1141        context.mark_cache_anchors(cache_configuration, false, cx)
1142    });
1143    assert_eq!(
1144        messages_cache(&context, cx)
1145            .iter()
1146            .map(|(_, cache)| cache
1147                .as_ref()
1148                .map_or(None, |cache| Some(cache.status.clone())))
1149            .collect::<Vec<Option<CacheStatus>>>(),
1150        vec![
1151            Some(CacheStatus::Pending),
1152            Some(CacheStatus::Pending),
1153            Some(CacheStatus::Pending),
1154            None
1155        ],
1156        "Modifying a message should invalidate all future messages."
1157    );
1158}
1159
1160fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1161    context
1162        .read(cx)
1163        .messages(cx)
1164        .map(|message| (message.id, message.role, message.offset_range))
1165        .collect()
1166}
1167
1168fn messages_cache(
1169    context: &Model<Context>,
1170    cx: &AppContext,
1171) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1172    context
1173        .read(cx)
1174        .messages(cx)
1175        .map(|message| (message.id, message.cache.clone()))
1176        .collect()
1177}
1178
1179#[derive(Clone)]
1180struct FakeSlashCommand(String);
1181
1182impl SlashCommand for FakeSlashCommand {
1183    fn name(&self) -> String {
1184        self.0.clone()
1185    }
1186
1187    fn description(&self) -> String {
1188        format!("Fake slash command: {}", self.0)
1189    }
1190
1191    fn menu_text(&self) -> String {
1192        format!("Run fake command: {}", self.0)
1193    }
1194
1195    fn complete_argument(
1196        self: Arc<Self>,
1197        _arguments: &[String],
1198        _cancel: Arc<AtomicBool>,
1199        _workspace: Option<WeakView<Workspace>>,
1200        _cx: &mut WindowContext,
1201    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1202        Task::ready(Ok(vec![]))
1203    }
1204
1205    fn requires_argument(&self) -> bool {
1206        false
1207    }
1208
1209    fn run(
1210        self: Arc<Self>,
1211        _arguments: &[String],
1212        _workspace: WeakView<Workspace>,
1213        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1214        _cx: &mut WindowContext,
1215    ) -> Task<Result<SlashCommandOutput>> {
1216        Task::ready(Ok(SlashCommandOutput {
1217            text: format!("Executed fake command: {}", self.0),
1218            sections: vec![],
1219            run_commands_in_text: false,
1220        }))
1221    }
1222}