context_tests.rs

   1use crate::{
   2    AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
   3    ContextOperation, ContextSummary, InvokedSlashCommandId, MessageCacheMetadata, MessageId,
   4    MessageStatus, RequestType,
   5};
   6use anyhow::Result;
   7use assistant_slash_command::{
   8    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   9    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
  10};
  11use assistant_slash_commands::FileSlashCommand;
  12use collections::{HashMap, HashSet};
  13use fs::FakeFs;
  14use futures::{
  15    channel::mpsc,
  16    stream::{self, StreamExt},
  17};
  18use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  19use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  20use language_model::{
  21    ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
  22    fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
  23};
  24use parking_lot::Mutex;
  25use pretty_assertions::assert_eq;
  26use project::Project;
  27use prompt_store::PromptBuilder;
  28use rand::prelude::*;
  29use serde_json::json;
  30use settings::SettingsStore;
  31use std::{
  32    cell::RefCell,
  33    env,
  34    ops::Range,
  35    path::Path,
  36    rc::Rc,
  37    sync::{Arc, atomic::AtomicBool},
  38};
  39use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
  40use ui::{IconName, Window};
  41use unindent::Unindent;
  42use util::{
  43    RandomCharIter,
  44    test::{generate_marked_text, marked_text_ranges},
  45};
  46use workspace::Workspace;
  47
  48#[gpui::test]
  49fn test_inserting_and_removing_messages(cx: &mut App) {
  50    init_test(cx);
  51
  52    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  53    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  54    let context = cx.new(|cx| {
  55        AssistantContext::local(
  56            registry,
  57            None,
  58            None,
  59            prompt_builder.clone(),
  60            Arc::new(SlashCommandWorkingSet::default()),
  61            cx,
  62        )
  63    });
  64    let buffer = context.read(cx).buffer.clone();
  65
  66    let message_1 = context.read(cx).message_anchors[0].clone();
  67    assert_eq!(
  68        messages(&context, cx),
  69        vec![(message_1.id, Role::User, 0..0)]
  70    );
  71
  72    let message_2 = context.update(cx, |context, cx| {
  73        context
  74            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  75            .unwrap()
  76    });
  77    assert_eq!(
  78        messages(&context, cx),
  79        vec![
  80            (message_1.id, Role::User, 0..1),
  81            (message_2.id, Role::Assistant, 1..1)
  82        ]
  83    );
  84
  85    buffer.update(cx, |buffer, cx| {
  86        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  87    });
  88    assert_eq!(
  89        messages(&context, cx),
  90        vec![
  91            (message_1.id, Role::User, 0..2),
  92            (message_2.id, Role::Assistant, 2..3)
  93        ]
  94    );
  95
  96    let message_3 = context.update(cx, |context, cx| {
  97        context
  98            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  99            .unwrap()
 100    });
 101    assert_eq!(
 102        messages(&context, cx),
 103        vec![
 104            (message_1.id, Role::User, 0..2),
 105            (message_2.id, Role::Assistant, 2..4),
 106            (message_3.id, Role::User, 4..4)
 107        ]
 108    );
 109
 110    let message_4 = context.update(cx, |context, cx| {
 111        context
 112            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 113            .unwrap()
 114    });
 115    assert_eq!(
 116        messages(&context, cx),
 117        vec![
 118            (message_1.id, Role::User, 0..2),
 119            (message_2.id, Role::Assistant, 2..4),
 120            (message_4.id, Role::User, 4..5),
 121            (message_3.id, Role::User, 5..5),
 122        ]
 123    );
 124
 125    buffer.update(cx, |buffer, cx| {
 126        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 127    });
 128    assert_eq!(
 129        messages(&context, cx),
 130        vec![
 131            (message_1.id, Role::User, 0..2),
 132            (message_2.id, Role::Assistant, 2..4),
 133            (message_4.id, Role::User, 4..6),
 134            (message_3.id, Role::User, 6..7),
 135        ]
 136    );
 137
 138    // Deleting across message boundaries merges the messages.
 139    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 140    assert_eq!(
 141        messages(&context, cx),
 142        vec![
 143            (message_1.id, Role::User, 0..3),
 144            (message_3.id, Role::User, 3..4),
 145        ]
 146    );
 147
 148    // Undoing the deletion should also undo the merge.
 149    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 150    assert_eq!(
 151        messages(&context, cx),
 152        vec![
 153            (message_1.id, Role::User, 0..2),
 154            (message_2.id, Role::Assistant, 2..4),
 155            (message_4.id, Role::User, 4..6),
 156            (message_3.id, Role::User, 6..7),
 157        ]
 158    );
 159
 160    // Redoing the deletion should also redo the merge.
 161    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 162    assert_eq!(
 163        messages(&context, cx),
 164        vec![
 165            (message_1.id, Role::User, 0..3),
 166            (message_3.id, Role::User, 3..4),
 167        ]
 168    );
 169
 170    // Ensure we can still insert after a merged message.
 171    let message_5 = context.update(cx, |context, cx| {
 172        context
 173            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 174            .unwrap()
 175    });
 176    assert_eq!(
 177        messages(&context, cx),
 178        vec![
 179            (message_1.id, Role::User, 0..3),
 180            (message_5.id, Role::System, 3..4),
 181            (message_3.id, Role::User, 4..5)
 182        ]
 183    );
 184}
 185
 186#[gpui::test]
 187fn test_message_splitting(cx: &mut App) {
 188    init_test(cx);
 189
 190    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 191
 192    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 193    let context = cx.new(|cx| {
 194        AssistantContext::local(
 195            registry.clone(),
 196            None,
 197            None,
 198            prompt_builder.clone(),
 199            Arc::new(SlashCommandWorkingSet::default()),
 200            cx,
 201        )
 202    });
 203    let buffer = context.read(cx).buffer.clone();
 204
 205    let message_1 = context.read(cx).message_anchors[0].clone();
 206    assert_eq!(
 207        messages(&context, cx),
 208        vec![(message_1.id, Role::User, 0..0)]
 209    );
 210
 211    buffer.update(cx, |buffer, cx| {
 212        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 213    });
 214
 215    let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 216    let message_2 = message_2.unwrap();
 217
 218    // We recycle newlines in the middle of a split message
 219    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 220    assert_eq!(
 221        messages(&context, cx),
 222        vec![
 223            (message_1.id, Role::User, 0..4),
 224            (message_2.id, Role::User, 4..16),
 225        ]
 226    );
 227
 228    let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
 229    let message_3 = message_3.unwrap();
 230
 231    // We don't recycle newlines at the end of a split message
 232    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 233    assert_eq!(
 234        messages(&context, cx),
 235        vec![
 236            (message_1.id, Role::User, 0..4),
 237            (message_3.id, Role::User, 4..5),
 238            (message_2.id, Role::User, 5..17),
 239        ]
 240    );
 241
 242    let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 243    let message_4 = message_4.unwrap();
 244    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 245    assert_eq!(
 246        messages(&context, cx),
 247        vec![
 248            (message_1.id, Role::User, 0..4),
 249            (message_3.id, Role::User, 4..5),
 250            (message_2.id, Role::User, 5..9),
 251            (message_4.id, Role::User, 9..17),
 252        ]
 253    );
 254
 255    let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
 256    let message_5 = message_5.unwrap();
 257    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 258    assert_eq!(
 259        messages(&context, cx),
 260        vec![
 261            (message_1.id, Role::User, 0..4),
 262            (message_3.id, Role::User, 4..5),
 263            (message_2.id, Role::User, 5..9),
 264            (message_4.id, Role::User, 9..10),
 265            (message_5.id, Role::User, 10..18),
 266        ]
 267    );
 268
 269    let (message_6, message_7) =
 270        context.update(cx, |context, cx| context.split_message(14..16, cx));
 271    let message_6 = message_6.unwrap();
 272    let message_7 = message_7.unwrap();
 273    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 274    assert_eq!(
 275        messages(&context, cx),
 276        vec![
 277            (message_1.id, Role::User, 0..4),
 278            (message_3.id, Role::User, 4..5),
 279            (message_2.id, Role::User, 5..9),
 280            (message_4.id, Role::User, 9..10),
 281            (message_5.id, Role::User, 10..14),
 282            (message_6.id, Role::User, 14..17),
 283            (message_7.id, Role::User, 17..19),
 284        ]
 285    );
 286}
 287
 288#[gpui::test]
 289fn test_messages_for_offsets(cx: &mut App) {
 290    init_test(cx);
 291
 292    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 293    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 294    let context = cx.new(|cx| {
 295        AssistantContext::local(
 296            registry,
 297            None,
 298            None,
 299            prompt_builder.clone(),
 300            Arc::new(SlashCommandWorkingSet::default()),
 301            cx,
 302        )
 303    });
 304    let buffer = context.read(cx).buffer.clone();
 305
 306    let message_1 = context.read(cx).message_anchors[0].clone();
 307    assert_eq!(
 308        messages(&context, cx),
 309        vec![(message_1.id, Role::User, 0..0)]
 310    );
 311
 312    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 313    let message_2 = context
 314        .update(cx, |context, cx| {
 315            context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 316        })
 317        .unwrap();
 318    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 319
 320    let message_3 = context
 321        .update(cx, |context, cx| {
 322            context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 323        })
 324        .unwrap();
 325    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 326
 327    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 328    assert_eq!(
 329        messages(&context, cx),
 330        vec![
 331            (message_1.id, Role::User, 0..4),
 332            (message_2.id, Role::User, 4..8),
 333            (message_3.id, Role::User, 8..11)
 334        ]
 335    );
 336
 337    assert_eq!(
 338        message_ids_for_offsets(&context, &[0, 4, 9], cx),
 339        [message_1.id, message_2.id, message_3.id]
 340    );
 341    assert_eq!(
 342        message_ids_for_offsets(&context, &[0, 1, 11], cx),
 343        [message_1.id, message_3.id]
 344    );
 345
 346    let message_4 = context
 347        .update(cx, |context, cx| {
 348            context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 349        })
 350        .unwrap();
 351    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 352    assert_eq!(
 353        messages(&context, cx),
 354        vec![
 355            (message_1.id, Role::User, 0..4),
 356            (message_2.id, Role::User, 4..8),
 357            (message_3.id, Role::User, 8..12),
 358            (message_4.id, Role::User, 12..12)
 359        ]
 360    );
 361    assert_eq!(
 362        message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
 363        [message_1.id, message_2.id, message_3.id, message_4.id]
 364    );
 365
 366    fn message_ids_for_offsets(
 367        context: &Entity<AssistantContext>,
 368        offsets: &[usize],
 369        cx: &App,
 370    ) -> Vec<MessageId> {
 371        context
 372            .read(cx)
 373            .messages_for_offsets(offsets.iter().copied(), cx)
 374            .into_iter()
 375            .map(|message| message.id)
 376            .collect()
 377    }
 378}
 379
 380#[gpui::test]
 381async fn test_slash_commands(cx: &mut TestAppContext) {
 382    cx.update(init_test);
 383
 384    let fs = FakeFs::new(cx.background_executor.clone());
 385
 386    fs.insert_tree(
 387        "/test",
 388        json!({
 389            "src": {
 390                "lib.rs": "fn one() -> usize { 1 }",
 391                "main.rs": "
 392                    use crate::one;
 393                    fn main() { one(); }
 394                ".unindent(),
 395            }
 396        }),
 397    )
 398    .await;
 399
 400    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 401    slash_command_registry.register_command(FileSlashCommand, false);
 402
 403    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 404    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 405    let context = cx.new(|cx| {
 406        AssistantContext::local(
 407            registry.clone(),
 408            None,
 409            None,
 410            prompt_builder.clone(),
 411            Arc::new(SlashCommandWorkingSet::default()),
 412            cx,
 413        )
 414    });
 415
 416    #[derive(Default)]
 417    struct ContextRanges {
 418        parsed_commands: HashSet<Range<language::Anchor>>,
 419        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 420        output_sections: HashSet<Range<language::Anchor>>,
 421    }
 422
 423    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 424    context.update(cx, |_, cx| {
 425        cx.subscribe(&context, {
 426            let context_ranges = context_ranges.clone();
 427            move |context, _, event, _| {
 428                let mut context_ranges = context_ranges.borrow_mut();
 429                match event {
 430                    ContextEvent::InvokedSlashCommandChanged { command_id } => {
 431                        let command = context.invoked_slash_command(command_id).unwrap();
 432                        context_ranges
 433                            .command_outputs
 434                            .insert(*command_id, command.range.clone());
 435                    }
 436                    ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 437                        for range in removed {
 438                            context_ranges.parsed_commands.remove(range);
 439                        }
 440                        for command in updated {
 441                            context_ranges
 442                                .parsed_commands
 443                                .insert(command.source_range.clone());
 444                        }
 445                    }
 446                    ContextEvent::SlashCommandOutputSectionAdded { section } => {
 447                        context_ranges.output_sections.insert(section.range.clone());
 448                    }
 449                    _ => {}
 450                }
 451            }
 452        })
 453        .detach();
 454    });
 455
 456    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
 457
 458    // Insert a slash command
 459    buffer.update(cx, |buffer, cx| {
 460        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 461    });
 462    assert_text_and_context_ranges(
 463        &buffer,
 464        &context_ranges,
 465        &"
 466        «/file src/lib.rs»"
 467            .unindent(),
 468        cx,
 469    );
 470
 471    // Edit the argument of the slash command.
 472    buffer.update(cx, |buffer, cx| {
 473        let edit_offset = buffer.text().find("lib.rs").unwrap();
 474        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 475    });
 476    assert_text_and_context_ranges(
 477        &buffer,
 478        &context_ranges,
 479        &"
 480        «/file src/main.rs»"
 481            .unindent(),
 482        cx,
 483    );
 484
 485    // Edit the name of the slash command, using one that doesn't exist.
 486    buffer.update(cx, |buffer, cx| {
 487        let edit_offset = buffer.text().find("/file").unwrap();
 488        buffer.edit(
 489            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 490            None,
 491            cx,
 492        );
 493    });
 494    assert_text_and_context_ranges(
 495        &buffer,
 496        &context_ranges,
 497        &"
 498        /unknown src/main.rs"
 499            .unindent(),
 500        cx,
 501    );
 502
 503    // Undoing the insertion of an non-existent slash command resorts the previous one.
 504    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 505    assert_text_and_context_ranges(
 506        &buffer,
 507        &context_ranges,
 508        &"
 509        «/file src/main.rs»"
 510            .unindent(),
 511        cx,
 512    );
 513
 514    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 515    context.update(cx, |context, cx| {
 516        let command_source_range = context.parsed_slash_commands[0].source_range.clone();
 517        context.insert_command_output(
 518            command_source_range,
 519            "file",
 520            Task::ready(Ok(command_output_rx.boxed())),
 521            true,
 522            cx,
 523        );
 524    });
 525    assert_text_and_context_ranges(
 526        &buffer,
 527        &context_ranges,
 528        &"
 529        ⟦«/file src/main.rs»
 530        …⟧
 531        "
 532        .unindent(),
 533        cx,
 534    );
 535
 536    command_output_tx
 537        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 538            icon: IconName::Ai,
 539            label: "src/main.rs".into(),
 540            metadata: None,
 541        }))
 542        .unwrap();
 543    command_output_tx
 544        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 545        .unwrap();
 546    cx.run_until_parked();
 547    assert_text_and_context_ranges(
 548        &buffer,
 549        &context_ranges,
 550        &"
 551        ⟦«/file src/main.rs»
 552        src/main.rs…⟧
 553        "
 554        .unindent(),
 555        cx,
 556    );
 557
 558    command_output_tx
 559        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 560        .unwrap();
 561    cx.run_until_parked();
 562    assert_text_and_context_ranges(
 563        &buffer,
 564        &context_ranges,
 565        &"
 566        ⟦«/file src/main.rs»
 567        src/main.rs
 568        fn main() {}…⟧
 569        "
 570        .unindent(),
 571        cx,
 572    );
 573
 574    command_output_tx
 575        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 576        .unwrap();
 577    cx.run_until_parked();
 578    assert_text_and_context_ranges(
 579        &buffer,
 580        &context_ranges,
 581        &"
 582        ⟦«/file src/main.rs»
 583        ⟪src/main.rs
 584        fn main() {}⟫…⟧
 585        "
 586        .unindent(),
 587        cx,
 588    );
 589
 590    drop(command_output_tx);
 591    cx.run_until_parked();
 592    assert_text_and_context_ranges(
 593        &buffer,
 594        &context_ranges,
 595        &"
 596        ⟦⟪src/main.rs
 597        fn main() {}⟫⟧
 598        "
 599        .unindent(),
 600        cx,
 601    );
 602
 603    #[track_caller]
 604    fn assert_text_and_context_ranges(
 605        buffer: &Entity<Buffer>,
 606        ranges: &RefCell<ContextRanges>,
 607        expected_marked_text: &str,
 608        cx: &mut TestAppContext,
 609    ) {
 610        let mut actual_marked_text = String::new();
 611        buffer.update(cx, |buffer, _| {
 612            struct Endpoint {
 613                offset: usize,
 614                marker: char,
 615            }
 616
 617            let ranges = ranges.borrow();
 618            let mut endpoints = Vec::new();
 619            for range in ranges.command_outputs.values() {
 620                endpoints.push(Endpoint {
 621                    offset: range.start.to_offset(buffer),
 622                    marker: '',
 623                });
 624            }
 625            for range in ranges.parsed_commands.iter() {
 626                endpoints.push(Endpoint {
 627                    offset: range.start.to_offset(buffer),
 628                    marker: '«',
 629                });
 630            }
 631            for range in ranges.output_sections.iter() {
 632                endpoints.push(Endpoint {
 633                    offset: range.start.to_offset(buffer),
 634                    marker: '',
 635                });
 636            }
 637
 638            for range in ranges.output_sections.iter() {
 639                endpoints.push(Endpoint {
 640                    offset: range.end.to_offset(buffer),
 641                    marker: '',
 642                });
 643            }
 644            for range in ranges.parsed_commands.iter() {
 645                endpoints.push(Endpoint {
 646                    offset: range.end.to_offset(buffer),
 647                    marker: '»',
 648                });
 649            }
 650            for range in ranges.command_outputs.values() {
 651                endpoints.push(Endpoint {
 652                    offset: range.end.to_offset(buffer),
 653                    marker: '',
 654                });
 655            }
 656
 657            endpoints.sort_by_key(|endpoint| endpoint.offset);
 658            let mut offset = 0;
 659            for endpoint in endpoints {
 660                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 661                actual_marked_text.push(endpoint.marker);
 662                offset = endpoint.offset;
 663            }
 664            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 665        });
 666
 667        assert_eq!(actual_marked_text, expected_marked_text);
 668    }
 669}
 670
 671#[gpui::test]
 672async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
 673    cx.update(|cx| {
 674        init_test(cx);
 675        cx.update_global(|settings_store: &mut SettingsStore, cx| {
 676            settings_store
 677                .set_user_settings(
 678                    r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
 679                    cx,
 680                )
 681                .unwrap()
 682        })
 683    });
 684    let fs = FakeFs::new(cx.executor());
 685    let project = Project::test(fs, [Path::new("/root")], cx).await;
 686
 687    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 688
 689    // Create a new context
 690    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 691    let context = cx.new(|cx| {
 692        AssistantContext::local(
 693            registry.clone(),
 694            Some(project),
 695            None,
 696            prompt_builder.clone(),
 697            Arc::new(SlashCommandWorkingSet::default()),
 698            cx,
 699        )
 700    });
 701
 702    // Insert an assistant message to simulate a response.
 703    let assistant_message_id = context.update(cx, |context, cx| {
 704        let user_message_id = context.messages(cx).next().unwrap().id;
 705        context
 706            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
 707            .unwrap()
 708            .id
 709    });
 710
 711    // No edit tags
 712    edit(
 713        &context,
 714        "
 715
 716        «one
 717        two
 718        »",
 719        cx,
 720    );
 721    expect_patches(
 722        &context,
 723        "
 724
 725        one
 726        two
 727        ",
 728        &[],
 729        cx,
 730    );
 731
 732    // Partial edit step tag is added
 733    edit(
 734        &context,
 735        "
 736
 737        one
 738        two
 739        «
 740        <patch»",
 741        cx,
 742    );
 743    expect_patches(
 744        &context,
 745        "
 746
 747        one
 748        two
 749
 750        <patch",
 751        &[],
 752        cx,
 753    );
 754
 755    // The rest of the step tag is added. The unclosed
 756    // step is treated as incomplete.
 757    edit(
 758        &context,
 759        "
 760
 761        one
 762        two
 763
 764        <patch«>
 765        <edit>»",
 766        cx,
 767    );
 768    expect_patches(
 769        &context,
 770        "
 771
 772        one
 773        two
 774
 775        «<patch>
 776        <edit>»",
 777        &[&[]],
 778        cx,
 779    );
 780
 781    // The full patch is added
 782    edit(
 783        &context,
 784        "
 785
 786        one
 787        two
 788
 789        <patch>
 790        <edit>«
 791        <description>add a `two` function</description>
 792        <path>src/lib.rs</path>
 793        <operation>insert_after</operation>
 794        <old_text>fn one</old_text>
 795        <new_text>
 796        fn two() {}
 797        </new_text>
 798        </edit>
 799        </patch>
 800
 801        also,»",
 802        cx,
 803    );
 804    expect_patches(
 805        &context,
 806        "
 807
 808        one
 809        two
 810
 811        «<patch>
 812        <edit>
 813        <description>add a `two` function</description>
 814        <path>src/lib.rs</path>
 815        <operation>insert_after</operation>
 816        <old_text>fn one</old_text>
 817        <new_text>
 818        fn two() {}
 819        </new_text>
 820        </edit>
 821        </patch>
 822        »
 823        also,",
 824        &[&[AssistantEdit {
 825            path: "src/lib.rs".into(),
 826            kind: AssistantEditKind::InsertAfter {
 827                old_text: "fn one".into(),
 828                new_text: "fn two() {}".into(),
 829                description: Some("add a `two` function".into()),
 830            },
 831        }]],
 832        cx,
 833    );
 834
 835    // The step is manually edited.
 836    edit(
 837        &context,
 838        "
 839
 840        one
 841        two
 842
 843        <patch>
 844        <edit>
 845        <description>add a `two` function</description>
 846        <path>src/lib.rs</path>
 847        <operation>insert_after</operation>
 848        <old_text>«fn zero»</old_text>
 849        <new_text>
 850        fn two() {}
 851        </new_text>
 852        </edit>
 853        </patch>
 854
 855        also,",
 856        cx,
 857    );
 858    expect_patches(
 859        &context,
 860        "
 861
 862        one
 863        two
 864
 865        «<patch>
 866        <edit>
 867        <description>add a `two` function</description>
 868        <path>src/lib.rs</path>
 869        <operation>insert_after</operation>
 870        <old_text>fn zero</old_text>
 871        <new_text>
 872        fn two() {}
 873        </new_text>
 874        </edit>
 875        </patch>
 876        »
 877        also,",
 878        &[&[AssistantEdit {
 879            path: "src/lib.rs".into(),
 880            kind: AssistantEditKind::InsertAfter {
 881                old_text: "fn zero".into(),
 882                new_text: "fn two() {}".into(),
 883                description: Some("add a `two` function".into()),
 884            },
 885        }]],
 886        cx,
 887    );
 888
 889    // When setting the message role to User, the steps are cleared.
 890    context.update(cx, |context, cx| {
 891        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 892        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 893    });
 894    expect_patches(
 895        &context,
 896        "
 897
 898        one
 899        two
 900
 901        <patch>
 902        <edit>
 903        <description>add a `two` function</description>
 904        <path>src/lib.rs</path>
 905        <operation>insert_after</operation>
 906        <old_text>fn zero</old_text>
 907        <new_text>
 908        fn two() {}
 909        </new_text>
 910        </edit>
 911        </patch>
 912
 913        also,",
 914        &[],
 915        cx,
 916    );
 917
 918    // When setting the message role back to Assistant, the steps are reparsed.
 919    context.update(cx, |context, cx| {
 920        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
 921    });
 922    expect_patches(
 923        &context,
 924        "
 925
 926        one
 927        two
 928
 929        «<patch>
 930        <edit>
 931        <description>add a `two` function</description>
 932        <path>src/lib.rs</path>
 933        <operation>insert_after</operation>
 934        <old_text>fn zero</old_text>
 935        <new_text>
 936        fn two() {}
 937        </new_text>
 938        </edit>
 939        </patch>
 940        »
 941        also,",
 942        &[&[AssistantEdit {
 943            path: "src/lib.rs".into(),
 944            kind: AssistantEditKind::InsertAfter {
 945                old_text: "fn zero".into(),
 946                new_text: "fn two() {}".into(),
 947                description: Some("add a `two` function".into()),
 948            },
 949        }]],
 950        cx,
 951    );
 952
 953    // Ensure steps are re-parsed when deserializing.
 954    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
 955    let deserialized_context = cx.new(|cx| {
 956        AssistantContext::deserialize(
 957            serialized_context,
 958            Path::new("").into(),
 959            registry.clone(),
 960            prompt_builder.clone(),
 961            Arc::new(SlashCommandWorkingSet::default()),
 962            None,
 963            None,
 964            cx,
 965        )
 966    });
 967    expect_patches(
 968        &deserialized_context,
 969        "
 970
 971        one
 972        two
 973
 974        «<patch>
 975        <edit>
 976        <description>add a `two` function</description>
 977        <path>src/lib.rs</path>
 978        <operation>insert_after</operation>
 979        <old_text>fn zero</old_text>
 980        <new_text>
 981        fn two() {}
 982        </new_text>
 983        </edit>
 984        </patch>
 985        »
 986        also,",
 987        &[&[AssistantEdit {
 988            path: "src/lib.rs".into(),
 989            kind: AssistantEditKind::InsertAfter {
 990                old_text: "fn zero".into(),
 991                new_text: "fn two() {}".into(),
 992                description: Some("add a `two` function".into()),
 993            },
 994        }]],
 995        cx,
 996    );
 997
 998    fn edit(
 999        context: &Entity<AssistantContext>,
1000        new_text_marked_with_edits: &str,
1001        cx: &mut TestAppContext,
1002    ) {
1003        context.update(cx, |context, cx| {
1004            context.buffer.update(cx, |buffer, cx| {
1005                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
1006            });
1007        });
1008        cx.executor().run_until_parked();
1009    }
1010
1011    #[track_caller]
1012    fn expect_patches(
1013        context: &Entity<AssistantContext>,
1014        expected_marked_text: &str,
1015        expected_suggestions: &[&[AssistantEdit]],
1016        cx: &mut TestAppContext,
1017    ) {
1018        let expected_marked_text = expected_marked_text.unindent();
1019        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
1020
1021        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
1022            context.buffer.read_with(cx, |buffer, _| {
1023                let ranges = context
1024                    .patches
1025                    .iter()
1026                    .map(|entry| entry.range.to_offset(buffer))
1027                    .collect::<Vec<_>>();
1028                (
1029                    buffer.text(),
1030                    ranges,
1031                    context
1032                        .patches
1033                        .iter()
1034                        .map(|step| step.edits.clone())
1035                        .collect::<Vec<_>>(),
1036                )
1037            })
1038        });
1039
1040        assert_eq!(buffer_text, expected_text);
1041
1042        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
1043        assert_eq!(actual_marked_text, expected_marked_text);
1044
1045        assert_eq!(
1046            patches
1047                .iter()
1048                .map(|patch| {
1049                    patch
1050                        .iter()
1051                        .map(|edit| {
1052                            let edit = edit.as_ref().unwrap();
1053                            AssistantEdit {
1054                                path: edit.path.clone(),
1055                                kind: edit.kind.clone(),
1056                            }
1057                        })
1058                        .collect::<Vec<_>>()
1059                })
1060                .collect::<Vec<_>>(),
1061            expected_suggestions
1062        );
1063    }
1064}
1065
1066#[gpui::test]
1067async fn test_serialization(cx: &mut TestAppContext) {
1068    cx.update(init_test);
1069
1070    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1071    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1072    let context = cx.new(|cx| {
1073        AssistantContext::local(
1074            registry.clone(),
1075            None,
1076            None,
1077            prompt_builder.clone(),
1078            Arc::new(SlashCommandWorkingSet::default()),
1079            cx,
1080        )
1081    });
1082    let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1083    let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1084    let message_1 = context.update(cx, |context, cx| {
1085        context
1086            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1087            .unwrap()
1088    });
1089    let message_2 = context.update(cx, |context, cx| {
1090        context
1091            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1092            .unwrap()
1093    });
1094    buffer.update(cx, |buffer, cx| {
1095        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1096        buffer.finalize_last_transaction();
1097    });
1098    let _message_3 = context.update(cx, |context, cx| {
1099        context
1100            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1101            .unwrap()
1102    });
1103    buffer.update(cx, |buffer, cx| buffer.undo(cx));
1104    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1105    assert_eq!(
1106        cx.read(|cx| messages(&context, cx)),
1107        [
1108            (message_0, Role::User, 0..2),
1109            (message_1.id, Role::Assistant, 2..6),
1110            (message_2.id, Role::System, 6..6),
1111        ]
1112    );
1113
1114    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1115    let deserialized_context = cx.new(|cx| {
1116        AssistantContext::deserialize(
1117            serialized_context,
1118            Path::new("").into(),
1119            registry.clone(),
1120            prompt_builder.clone(),
1121            Arc::new(SlashCommandWorkingSet::default()),
1122            None,
1123            None,
1124            cx,
1125        )
1126    });
1127    let deserialized_buffer =
1128        deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1129    assert_eq!(
1130        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1131        "a\nb\nc\n"
1132    );
1133    assert_eq!(
1134        cx.read(|cx| messages(&deserialized_context, cx)),
1135        [
1136            (message_0, Role::User, 0..2),
1137            (message_1.id, Role::Assistant, 2..6),
1138            (message_2.id, Role::System, 6..6),
1139        ]
1140    );
1141}
1142
/// Randomized convergence test for collaborative context editing.
///
/// Setup:
/// - Creates between `MIN_PEERS` and `MAX_PEERS` replicas (defaults 2 and 5)
///   of the same context, linked through a simulated `Network`.
///
/// Run:
/// - Performs `OPERATIONS` (default 50) random mutations: buffer edits,
///   message splits, message insertions, slash-command output insertions,
///   and message status updates.
/// - Interleaved with these, replicas randomly disconnect, reconnect
///   (exchanging missed operations with replica 0), and apply operations
///   received from the network.
///
/// Checks:
/// - Once the mutations are used up and the network is idle with no
///   disconnected peers, every replica must match replica 0.
#[gpui::test(iterations = 100)]
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
    cx.update(init_test);

    // Test size can be tuned through environment variables.
    let min_peers = env::var("MIN_PEERS")
        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
        .unwrap_or(2);
    let max_peers = env::var("MAX_PEERS")
        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
        .unwrap_or(5);
    let operations = env::var("OPERATIONS")
        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
        .unwrap_or(50);

    // Register a few fake commands so the slash-command mutation below has
    // names to pick from.
    let slash_commands = cx.update(SlashCommandRegistry::default_global);
    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
    let mut contexts = Vec::new();

    let num_peers = rng.gen_range(min_peers..=max_peers);
    let context_id = ContextId::new();
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    for i in 0..num_peers {
        // Every replica shares `context_id`; replica `i` uses `i` as its replica id.
        let context = cx.new(|cx| {
            AssistantContext::new(
                context_id.clone(),
                i as ReplicaId,
                language::Capability::ReadWrite,
                registry.clone(),
                prompt_builder.clone(),
                Arc::new(SlashCommandWorkingSet::default()),
                None,
                None,
                cx,
            )
        });

        // Broadcast every operation this replica emits to the simulated network.
        cx.update(|cx| {
            cx.subscribe(&context, {
                let network = network.clone();
                move |_, event, _| {
                    if let ContextEvent::Operation(op) = event {
                        network
                            .lock()
                            .broadcast(i as ReplicaId, vec![op.to_proto()]);
                    }
                }
            })
            .detach();
        });

        contexts.push(context);
        network.lock().add_peer(i as ReplicaId);
    }

    let mut mutation_count = operations;

    // Keep going until all mutations are spent, nothing is in flight, and no
    // replica is still disconnected.
    while mutation_count > 0
        || !network.lock().is_idle()
        || network.lock().contains_disconnected_peers()
    {
        let context_index = rng.gen_range(0..contexts.len());
        let context = &contexts[context_index];

        match rng.gen_range(0..100) {
            // 30%: apply one random edit directly to the buffer.
            0..=29 if mutation_count > 0 => {
                log::info!("Context {}: edit buffer", context_index);
                context.update(cx, |context, cx| {
                    context
                        .buffer
                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                });
                mutation_count -= 1;
            }
            // 15%: split a message at a random byte range.
            30..=44 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                    log::info!("Context {}: split message at {:?}", context_index, range);
                    context.split_message(range, cx);
                });
                mutation_count -= 1;
            }
            // 15%: insert a message with a random role after a random message.
            45..=59 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let role = *[Role::User, Role::Assistant, Role::System]
                            .choose(&mut rng)
                            .unwrap();
                        log::info!(
                            "Context {}: insert message after {:?} with {:?}",
                            context_index,
                            message.id,
                            role
                        );
                        context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                    }
                });
                mutation_count -= 1;
            }
            // 15%: insert a random slash command and feed it randomly sectioned output.
            60..=74 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    let command_text = "/".to_string()
                        + slash_commands
                            .command_names()
                            .choose(&mut rng)
                            .unwrap()
                            .clone()
                            .as_ref();

                    // Put the command on its own line at a random offset.
                    // `command_range` covers just the command text, after the
                    // leading newline.
                    let command_range = context.buffer.update(cx, |buffer, cx| {
                        let offset = buffer.random_byte_range(0, &mut rng).start;
                        buffer.edit(
                            [(offset..offset, format!("\n{}\n", command_text))],
                            None,
                            cx,
                        );
                        offset + 1..offset + 1 + command_text.len()
                    });

                    // Ten random characters of output, excluding `\r`
                    // (NOTE(review): presumably to sidestep line-ending
                    // normalization; confirm).
                    let output_text = RandomCharIter::new(&mut rng)
                        .filter(|c| *c != '\r')
                        .take(10)
                        .collect::<String>();

                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
                        role: Role::User,
                        merge_same_roles: true,
                    })];

                    // Cut the output into up to three consecutive sections.
                    // `section_end` is advanced to a char boundary so slicing
                    // multi-byte text cannot panic.
                    let num_sections = rng.gen_range(0..=3);
                    let mut section_start = 0;
                    for _ in 0..num_sections {
                        let mut section_end = rng.gen_range(section_start..=output_text.len());
                        while !output_text.is_char_boundary(section_end) {
                            section_end += 1;
                        }
                        events.push(Ok(SlashCommandEvent::StartSection {
                            icon: IconName::Ai,
                            label: "section".into(),
                            metadata: None,
                        }));
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..section_end].to_string(),
                            run_commands_in_text: false,
                        })));
                        events.push(Ok(SlashCommandEvent::EndSection));
                        section_start = section_end;
                    }

                    // Any text after the last section is emitted as unsectioned content.
                    if section_start < output_text.len() {
                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
                            text: output_text[section_start..].to_string(),
                            run_commands_in_text: false,
                        })));
                    }

                    log::info!(
                        "Context {}: insert slash command output at {:?} with {:?} events",
                        context_index,
                        command_range,
                        events.len()
                    );

                    let command_range = context.buffer.read(cx).anchor_after(command_range.start)
                        ..context.buffer.read(cx).anchor_after(command_range.end);
                    context.insert_command_output(
                        command_range,
                        "/command",
                        Task::ready(Ok(stream::iter(events).boxed())),
                        true,
                        cx,
                    );
                });
                // Drain pending work so the command output is processed before
                // the next step.
                cx.run_until_parked();
                mutation_count -= 1;
            }
            // 10%: set a random message's status to Done, Pending, or Error.
            75..=84 if mutation_count > 0 => {
                context.update(cx, |context, cx| {
                    if let Some(message) = context.messages(cx).choose(&mut rng) {
                        let new_status = match rng.gen_range(0..3) {
                            0 => MessageStatus::Done,
                            1 => MessageStatus::Pending,
                            _ => MessageStatus::Error(SharedString::from("Random error")),
                        };
                        log::info!(
                            "Context {}: update message {:?} status to {:?}",
                            context_index,
                            message.id,
                            new_status
                        );
                        context.update_metadata(message.id, cx, |metadata| {
                            metadata.status = new_status;
                        });
                    }
                });
                mutation_count -= 1;
            }
            // Network activity for this replica. This arm takes rolls 85..=99,
            // and every roll once the mutations are exhausted.
            _ => {
                let replica_id = context_index as ReplicaId;
                if network.lock().is_disconnected(replica_id) {
                    // Reconnect, then exchange with replica 0 (the host) the
                    // serialized operations each side lacks according to the
                    // other side's version.
                    network.lock().reconnect_peer(replica_id, 0);

                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                        let host_context = &contexts[0].read(cx);
                        let guest_context = context.read(cx);
                        (
                            guest_context.serialize_ops(&host_context.version(cx), cx),
                            host_context.serialize_ops(&guest_context.version(cx), cx),
                        )
                    });
                    let ops_to_send = ops_to_send.await;
                    let ops_to_receive = ops_to_receive
                        .await
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    log::info!(
                        "Context {}: reconnecting. Sent {} operations, received {} operations",
                        context_index,
                        ops_to_send.len(),
                        ops_to_receive.len()
                    );

                    network.lock().broadcast(replica_id, ops_to_send);
                    context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
                } else if rng.gen_bool(0.1) && replica_id != 0 {
                    // Replica 0 is the host and is never disconnected.
                    log::info!("Context {}: disconnecting", context_index);
                    network.lock().disconnect_peer(replica_id);
                } else if network.lock().has_unreceived(replica_id) {
                    log::info!("Context {}: applying operations", context_index);
                    let ops = network.lock().receive(replica_id);
                    let ops = ops
                        .into_iter()
                        .map(ContextOperation::from_proto)
                        .collect::<Result<Vec<_>>>()
                        .unwrap();
                    context.update(cx, |context, cx| context.apply_ops(ops, cx));
                }
            }
        }
    }

    // Every replica must have converged to replica 0's state, with no
    // operations left pending.
    cx.read(|cx| {
        let first_context = contexts[0].read(cx);
        for context in &contexts[1..] {
            let context = context.read(cx);
            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
            assert_eq!(
                context.buffer.read(cx).text(),
                first_context.buffer.read(cx).text(),
                "Context {} text != Context 0 text",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.message_anchors,
                first_context.message_anchors,
                "Context {} messages != Context 0 messages",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.messages_metadata,
                first_context.messages_metadata,
                "Context {} message metadata != Context 0 message metadata",
                context.buffer.read(cx).replica_id()
            );
            assert_eq!(
                context.slash_command_output_sections,
                first_context.slash_command_output_sections,
                "Context {} slash command output sections != Context 0 slash command output sections",
                context.buffer.read(cx).replica_id()
            );
        }
    });
}
1422
1423#[gpui::test]
1424fn test_mark_cache_anchors(cx: &mut App) {
1425    init_test(cx);
1426
1427    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1428    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1429    let context = cx.new(|cx| {
1430        AssistantContext::local(
1431            registry,
1432            None,
1433            None,
1434            prompt_builder.clone(),
1435            Arc::new(SlashCommandWorkingSet::default()),
1436            cx,
1437        )
1438    });
1439    let buffer = context.read(cx).buffer.clone();
1440
1441    // Create a test cache configuration
1442    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1443        max_cache_anchors: 3,
1444        should_speculate: true,
1445        min_total_token: 10,
1446    });
1447
1448    let message_1 = context.read(cx).message_anchors[0].clone();
1449
1450    context.update(cx, |context, cx| {
1451        context.mark_cache_anchors(cache_configuration, false, cx)
1452    });
1453
1454    assert_eq!(
1455        messages_cache(&context, cx)
1456            .iter()
1457            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1458            .count(),
1459        0,
1460        "Empty messages should not have any cache anchors."
1461    );
1462
1463    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1464    let message_2 = context
1465        .update(cx, |context, cx| {
1466            context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1467        })
1468        .unwrap();
1469
1470    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1471    let message_3 = context
1472        .update(cx, |context, cx| {
1473            context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1474        })
1475        .unwrap();
1476    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1477
1478    context.update(cx, |context, cx| {
1479        context.mark_cache_anchors(cache_configuration, false, cx)
1480    });
1481    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1482    assert_eq!(
1483        messages_cache(&context, cx)
1484            .iter()
1485            .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1486            .count(),
1487        0,
1488        "Messages should not be marked for cache before going over the token minimum."
1489    );
1490    context.update(cx, |context, _| {
1491        context.token_count = Some(20);
1492    });
1493
1494    context.update(cx, |context, cx| {
1495        context.mark_cache_anchors(cache_configuration, true, cx)
1496    });
1497    assert_eq!(
1498        messages_cache(&context, cx)
1499            .iter()
1500            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1501            .collect::<Vec<bool>>(),
1502        vec![true, true, false],
1503        "Last message should not be an anchor on speculative request."
1504    );
1505
1506    context
1507        .update(cx, |context, cx| {
1508            context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1509        })
1510        .unwrap();
1511
1512    context.update(cx, |context, cx| {
1513        context.mark_cache_anchors(cache_configuration, false, cx)
1514    });
1515    assert_eq!(
1516        messages_cache(&context, cx)
1517            .iter()
1518            .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1519            .collect::<Vec<bool>>(),
1520        vec![false, true, true, false],
1521        "Most recent message should also be cached if not a speculative request."
1522    );
1523    context.update(cx, |context, cx| {
1524        context.update_cache_status_for_completion(cx)
1525    });
1526    assert_eq!(
1527        messages_cache(&context, cx)
1528            .iter()
1529            .map(|(_, cache)| cache
1530                .as_ref()
1531                .map_or(None, |cache| Some(cache.status.clone())))
1532            .collect::<Vec<Option<CacheStatus>>>(),
1533        vec![
1534            Some(CacheStatus::Cached),
1535            Some(CacheStatus::Cached),
1536            Some(CacheStatus::Cached),
1537            None
1538        ],
1539        "All user messages prior to anchor should be marked as cached."
1540    );
1541
1542    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1543    context.update(cx, |context, cx| {
1544        context.mark_cache_anchors(cache_configuration, false, cx)
1545    });
1546    assert_eq!(
1547        messages_cache(&context, cx)
1548            .iter()
1549            .map(|(_, cache)| cache
1550                .as_ref()
1551                .map_or(None, |cache| Some(cache.status.clone())))
1552            .collect::<Vec<Option<CacheStatus>>>(),
1553        vec![
1554            Some(CacheStatus::Cached),
1555            Some(CacheStatus::Cached),
1556            Some(CacheStatus::Pending),
1557            None
1558        ],
1559        "Modifying a message should invalidate it's cache but leave previous messages."
1560    );
1561    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1562    context.update(cx, |context, cx| {
1563        context.mark_cache_anchors(cache_configuration, false, cx)
1564    });
1565    assert_eq!(
1566        messages_cache(&context, cx)
1567            .iter()
1568            .map(|(_, cache)| cache
1569                .as_ref()
1570                .map_or(None, |cache| Some(cache.status.clone())))
1571            .collect::<Vec<Option<CacheStatus>>>(),
1572        vec![
1573            Some(CacheStatus::Pending),
1574            Some(CacheStatus::Pending),
1575            Some(CacheStatus::Pending),
1576            None
1577        ],
1578        "Modifying a message should invalidate all future messages."
1579    );
1580}
1581
1582#[gpui::test]
1583async fn test_summarization(cx: &mut TestAppContext) {
1584    let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1585
1586    // Initial state should be pending
1587    context.read_with(cx, |context, _| {
1588        assert!(matches!(context.summary(), ContextSummary::Pending));
1589        assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
1590    });
1591
1592    let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
1593    context.update(cx, |context, cx| {
1594        context
1595            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1596            .unwrap();
1597    });
1598
1599    // Send a message
1600    context.update(cx, |context, cx| {
1601        context.assist(RequestType::Chat, cx);
1602    });
1603
1604    simulate_successful_response(&fake_model, cx);
1605
1606    // Should start generating summary when there are >= 2 messages
1607    context.read_with(cx, |context, _| {
1608        assert!(!context.summary().content().unwrap().done);
1609    });
1610
1611    cx.run_until_parked();
1612    fake_model.stream_last_completion_response("Brief".into());
1613    fake_model.stream_last_completion_response(" Introduction".into());
1614    fake_model.end_last_completion_stream();
1615    cx.run_until_parked();
1616
1617    // Summary should be set
1618    context.read_with(cx, |context, _| {
1619        assert_eq!(context.summary().or_default(), "Brief Introduction");
1620    });
1621
1622    // We should be able to manually set a summary
1623    context.update(cx, |context, cx| {
1624        context.set_custom_summary("Brief Intro".into(), cx);
1625    });
1626
1627    context.read_with(cx, |context, _| {
1628        assert_eq!(context.summary().or_default(), "Brief Intro");
1629    });
1630}
1631
1632#[gpui::test]
1633async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1634    let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1635
1636    test_summarize_error(&fake_model, &context, cx);
1637
1638    // Now we should be able to set a summary
1639    context.update(cx, |context, cx| {
1640        context.set_custom_summary("Brief Intro".into(), cx);
1641    });
1642
1643    context.read_with(cx, |context, _| {
1644        assert_eq!(context.summary().or_default(), "Brief Intro");
1645    });
1646}
1647
1648#[gpui::test]
1649async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1650    let (context, fake_model) = setup_context_editor_with_fake_model(cx);
1651
1652    test_summarize_error(&fake_model, &context, cx);
1653
1654    // Sending another message should not trigger another summarize request
1655    context.update(cx, |context, cx| {
1656        context.assist(RequestType::Chat, cx);
1657    });
1658
1659    simulate_successful_response(&fake_model, cx);
1660
1661    context.read_with(cx, |context, _| {
1662        // State is still Error, not Generating
1663        assert!(matches!(context.summary(), ContextSummary::Error));
1664    });
1665
1666    // But the summarize request can be invoked manually
1667    context.update(cx, |context, cx| {
1668        context.summarize(true, cx);
1669    });
1670
1671    context.read_with(cx, |context, _| {
1672        assert!(!context.summary().content().unwrap().done);
1673    });
1674
1675    cx.run_until_parked();
1676    fake_model.stream_last_completion_response("A successful summary".into());
1677    fake_model.end_last_completion_stream();
1678    cx.run_until_parked();
1679
1680    context.read_with(cx, |context, _| {
1681        assert_eq!(context.summary().or_default(), "A successful summary");
1682    });
1683}
1684
1685fn test_summarize_error(
1686    model: &Arc<FakeLanguageModel>,
1687    context: &Entity<AssistantContext>,
1688    cx: &mut TestAppContext,
1689) {
1690    let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
1691    context.update(cx, |context, cx| {
1692        context
1693            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1694            .unwrap();
1695    });
1696
1697    // Send a message
1698    context.update(cx, |context, cx| {
1699        context.assist(RequestType::Chat, cx);
1700    });
1701
1702    simulate_successful_response(&model, cx);
1703
1704    context.read_with(cx, |context, _| {
1705        assert!(!context.summary().content().unwrap().done);
1706    });
1707
1708    // Simulate summary request ending
1709    cx.run_until_parked();
1710    model.end_last_completion_stream();
1711    cx.run_until_parked();
1712
1713    // State is set to Error and default message
1714    context.read_with(cx, |context, _| {
1715        assert_eq!(*context.summary(), ContextSummary::Error);
1716        assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
1717    });
1718}
1719
1720fn setup_context_editor_with_fake_model(
1721    cx: &mut TestAppContext,
1722) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
1723    let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
1724
1725    let fake_provider = Arc::new(FakeLanguageModelProvider);
1726    let fake_model = Arc::new(fake_provider.test_model());
1727
1728    cx.update(|cx| {
1729        init_test(cx);
1730        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1731            registry.set_default_model(
1732                Some(ConfiguredModel {
1733                    provider: fake_provider.clone(),
1734                    model: fake_model.clone(),
1735                }),
1736                cx,
1737            )
1738        })
1739    });
1740
1741    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1742    let context = cx.new(|cx| {
1743        AssistantContext::local(
1744            registry,
1745            None,
1746            None,
1747            prompt_builder.clone(),
1748            Arc::new(SlashCommandWorkingSet::default()),
1749            cx,
1750        )
1751    });
1752
1753    (context, fake_model)
1754}
1755
1756fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1757    cx.run_until_parked();
1758    fake_model.stream_last_completion_response("Assistant response".into());
1759    fake_model.end_last_completion_stream();
1760    cx.run_until_parked();
1761}
1762
1763fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1764    context
1765        .read(cx)
1766        .messages(cx)
1767        .map(|message| (message.id, message.role, message.offset_range))
1768        .collect()
1769}
1770
1771fn messages_cache(
1772    context: &Entity<AssistantContext>,
1773    cx: &App,
1774) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1775    context
1776        .read(cx)
1777        .messages(cx)
1778        .map(|message| (message.id, message.cache.clone()))
1779        .collect()
1780}
1781
/// Installs the globals the tests in this file rely on.
///
/// Steps:
/// - Creates a test `SettingsStore`.
/// - Initializes the prompt store.
/// - Installs a test `LanguageModelRegistry`.
/// - Installs the settings store as a global.
/// - Runs `language::init`, `assistant_settings::init`, and
///   `Project::init_settings`.
///
/// NOTE(review): the settings store is created first but only installed after
/// `prompt_store::init` and `LanguageModelRegistry::test`. It is unclear
/// whether that ordering is load-bearing, so don't reorder without checking.
fn init_test(cx: &mut App) {
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    language::init(cx);
    assistant_settings::init(cx);
    Project::init_settings(cx);
}
1791
1792#[derive(Clone)]
1793struct FakeSlashCommand(String);
1794
1795impl SlashCommand for FakeSlashCommand {
1796    fn name(&self) -> String {
1797        self.0.clone()
1798    }
1799
1800    fn description(&self) -> String {
1801        format!("Fake slash command: {}", self.0)
1802    }
1803
1804    fn menu_text(&self) -> String {
1805        format!("Run fake command: {}", self.0)
1806    }
1807
1808    fn complete_argument(
1809        self: Arc<Self>,
1810        _arguments: &[String],
1811        _cancel: Arc<AtomicBool>,
1812        _workspace: Option<WeakEntity<Workspace>>,
1813        _window: &mut Window,
1814        _cx: &mut App,
1815    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1816        Task::ready(Ok(vec![]))
1817    }
1818
1819    fn requires_argument(&self) -> bool {
1820        false
1821    }
1822
1823    fn run(
1824        self: Arc<Self>,
1825        _arguments: &[String],
1826        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1827        _context_buffer: BufferSnapshot,
1828        _workspace: WeakEntity<Workspace>,
1829        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1830        _window: &mut Window,
1831        _cx: &mut App,
1832    ) -> Task<SlashCommandResult> {
1833        Task::ready(Ok(SlashCommandOutput {
1834            text: format!("Executed fake command: {}", self.0),
1835            sections: vec![],
1836            run_commands_in_text: false,
1837        }
1838        .to_event_stream()))
1839    }
1840}