assistant_text_thread_tests.rs

   1use crate::{
   2    CacheStatus, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, TextThread,
   3    TextThreadEvent, TextThreadId, TextThreadOperation, TextThreadSummary,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{
  20    ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
  21    fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
  22};
  23use parking_lot::Mutex;
  24use pretty_assertions::assert_eq;
  25use prompt_store::PromptBuilder;
  26use rand::prelude::*;
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{
  30    cell::RefCell,
  31    env,
  32    ops::Range,
  33    path::Path,
  34    rc::Rc,
  35    sync::{Arc, atomic::AtomicBool},
  36};
  37use text::{ReplicaId, ToOffset, network::Network};
  38use ui::{IconName, Window};
  39use unindent::Unindent;
  40use util::RandomCharIter;
  41use workspace::Workspace;
  42
  43#[gpui::test]
  44fn test_inserting_and_removing_messages(cx: &mut App) {
  45    init_test(cx);
  46
  47    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  48    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  49    let text_thread = cx.new(|cx| {
  50        TextThread::local(
  51            registry,
  52            None,
  53            prompt_builder.clone(),
  54            Arc::new(SlashCommandWorkingSet::default()),
  55            cx,
  56        )
  57    });
  58    let buffer = text_thread.read(cx).buffer().clone();
  59
  60    let message_1 = text_thread.read(cx).message_anchors[0].clone();
  61    assert_eq!(
  62        messages(&text_thread, cx),
  63        vec![(message_1.id, Role::User, 0..0)]
  64    );
  65
  66    let message_2 = text_thread.update(cx, |context, cx| {
  67        context
  68            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  69            .unwrap()
  70    });
  71    assert_eq!(
  72        messages(&text_thread, cx),
  73        vec![
  74            (message_1.id, Role::User, 0..1),
  75            (message_2.id, Role::Assistant, 1..1)
  76        ]
  77    );
  78
  79    buffer.update(cx, |buffer, cx| {
  80        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  81    });
  82    assert_eq!(
  83        messages(&text_thread, cx),
  84        vec![
  85            (message_1.id, Role::User, 0..2),
  86            (message_2.id, Role::Assistant, 2..3)
  87        ]
  88    );
  89
  90    let message_3 = text_thread.update(cx, |context, cx| {
  91        context
  92            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  93            .unwrap()
  94    });
  95    assert_eq!(
  96        messages(&text_thread, cx),
  97        vec![
  98            (message_1.id, Role::User, 0..2),
  99            (message_2.id, Role::Assistant, 2..4),
 100            (message_3.id, Role::User, 4..4)
 101        ]
 102    );
 103
 104    let message_4 = text_thread.update(cx, |context, cx| {
 105        context
 106            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 107            .unwrap()
 108    });
 109    assert_eq!(
 110        messages(&text_thread, cx),
 111        vec![
 112            (message_1.id, Role::User, 0..2),
 113            (message_2.id, Role::Assistant, 2..4),
 114            (message_4.id, Role::User, 4..5),
 115            (message_3.id, Role::User, 5..5),
 116        ]
 117    );
 118
 119    buffer.update(cx, |buffer, cx| {
 120        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 121    });
 122    assert_eq!(
 123        messages(&text_thread, cx),
 124        vec![
 125            (message_1.id, Role::User, 0..2),
 126            (message_2.id, Role::Assistant, 2..4),
 127            (message_4.id, Role::User, 4..6),
 128            (message_3.id, Role::User, 6..7),
 129        ]
 130    );
 131
 132    // Deleting across message boundaries merges the messages.
 133    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 134    assert_eq!(
 135        messages(&text_thread, cx),
 136        vec![
 137            (message_1.id, Role::User, 0..3),
 138            (message_3.id, Role::User, 3..4),
 139        ]
 140    );
 141
 142    // Undoing the deletion should also undo the merge.
 143    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 144    assert_eq!(
 145        messages(&text_thread, cx),
 146        vec![
 147            (message_1.id, Role::User, 0..2),
 148            (message_2.id, Role::Assistant, 2..4),
 149            (message_4.id, Role::User, 4..6),
 150            (message_3.id, Role::User, 6..7),
 151        ]
 152    );
 153
 154    // Redoing the deletion should also redo the merge.
 155    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 156    assert_eq!(
 157        messages(&text_thread, cx),
 158        vec![
 159            (message_1.id, Role::User, 0..3),
 160            (message_3.id, Role::User, 3..4),
 161        ]
 162    );
 163
 164    // Ensure we can still insert after a merged message.
 165    let message_5 = text_thread.update(cx, |context, cx| {
 166        context
 167            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 168            .unwrap()
 169    });
 170    assert_eq!(
 171        messages(&text_thread, cx),
 172        vec![
 173            (message_1.id, Role::User, 0..3),
 174            (message_5.id, Role::System, 3..4),
 175            (message_3.id, Role::User, 4..5)
 176        ]
 177    );
 178}
 179
 180#[gpui::test]
 181fn test_message_splitting(cx: &mut App) {
 182    init_test(cx);
 183
 184    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 185
 186    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 187    let text_thread = cx.new(|cx| {
 188        TextThread::local(
 189            registry.clone(),
 190            None,
 191            prompt_builder.clone(),
 192            Arc::new(SlashCommandWorkingSet::default()),
 193            cx,
 194        )
 195    });
 196    let buffer = text_thread.read(cx).buffer().clone();
 197
 198    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 199    assert_eq!(
 200        messages(&text_thread, cx),
 201        vec![(message_1.id, Role::User, 0..0)]
 202    );
 203
 204    buffer.update(cx, |buffer, cx| {
 205        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 206    });
 207
 208    let (_, message_2) =
 209        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 210    let message_2 = message_2.unwrap();
 211
 212    // We recycle newlines in the middle of a split message
 213    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 214    assert_eq!(
 215        messages(&text_thread, cx),
 216        vec![
 217            (message_1.id, Role::User, 0..4),
 218            (message_2.id, Role::User, 4..16),
 219        ]
 220    );
 221
 222    let (_, message_3) =
 223        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 224    let message_3 = message_3.unwrap();
 225
 226    // We don't recycle newlines at the end of a split message
 227    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 228    assert_eq!(
 229        messages(&text_thread, cx),
 230        vec![
 231            (message_1.id, Role::User, 0..4),
 232            (message_3.id, Role::User, 4..5),
 233            (message_2.id, Role::User, 5..17),
 234        ]
 235    );
 236
 237    let (_, message_4) =
 238        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 239    let message_4 = message_4.unwrap();
 240    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 241    assert_eq!(
 242        messages(&text_thread, cx),
 243        vec![
 244            (message_1.id, Role::User, 0..4),
 245            (message_3.id, Role::User, 4..5),
 246            (message_2.id, Role::User, 5..9),
 247            (message_4.id, Role::User, 9..17),
 248        ]
 249    );
 250
 251    let (_, message_5) =
 252        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 253    let message_5 = message_5.unwrap();
 254    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 255    assert_eq!(
 256        messages(&text_thread, cx),
 257        vec![
 258            (message_1.id, Role::User, 0..4),
 259            (message_3.id, Role::User, 4..5),
 260            (message_2.id, Role::User, 5..9),
 261            (message_4.id, Role::User, 9..10),
 262            (message_5.id, Role::User, 10..18),
 263        ]
 264    );
 265
 266    let (message_6, message_7) =
 267        text_thread.update(cx, |text_thread, cx| text_thread.split_message(14..16, cx));
 268    let message_6 = message_6.unwrap();
 269    let message_7 = message_7.unwrap();
 270    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 271    assert_eq!(
 272        messages(&text_thread, cx),
 273        vec![
 274            (message_1.id, Role::User, 0..4),
 275            (message_3.id, Role::User, 4..5),
 276            (message_2.id, Role::User, 5..9),
 277            (message_4.id, Role::User, 9..10),
 278            (message_5.id, Role::User, 10..14),
 279            (message_6.id, Role::User, 14..17),
 280            (message_7.id, Role::User, 17..19),
 281        ]
 282    );
 283}
 284
 285#[gpui::test]
 286fn test_messages_for_offsets(cx: &mut App) {
 287    init_test(cx);
 288
 289    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 290    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 291    let text_thread = cx.new(|cx| {
 292        TextThread::local(
 293            registry,
 294            None,
 295            prompt_builder.clone(),
 296            Arc::new(SlashCommandWorkingSet::default()),
 297            cx,
 298        )
 299    });
 300    let buffer = text_thread.read(cx).buffer().clone();
 301
 302    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 303    assert_eq!(
 304        messages(&text_thread, cx),
 305        vec![(message_1.id, Role::User, 0..0)]
 306    );
 307
 308    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 309    let message_2 = text_thread
 310        .update(cx, |text_thread, cx| {
 311            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 312        })
 313        .unwrap();
 314    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 315
 316    let message_3 = text_thread
 317        .update(cx, |text_thread, cx| {
 318            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 319        })
 320        .unwrap();
 321    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 322
 323    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 324    assert_eq!(
 325        messages(&text_thread, cx),
 326        vec![
 327            (message_1.id, Role::User, 0..4),
 328            (message_2.id, Role::User, 4..8),
 329            (message_3.id, Role::User, 8..11)
 330        ]
 331    );
 332
 333    assert_eq!(
 334        message_ids_for_offsets(&text_thread, &[0, 4, 9], cx),
 335        [message_1.id, message_2.id, message_3.id]
 336    );
 337    assert_eq!(
 338        message_ids_for_offsets(&text_thread, &[0, 1, 11], cx),
 339        [message_1.id, message_3.id]
 340    );
 341
 342    let message_4 = text_thread
 343        .update(cx, |text_thread, cx| {
 344            text_thread.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 345        })
 346        .unwrap();
 347    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 348    assert_eq!(
 349        messages(&text_thread, cx),
 350        vec![
 351            (message_1.id, Role::User, 0..4),
 352            (message_2.id, Role::User, 4..8),
 353            (message_3.id, Role::User, 8..12),
 354            (message_4.id, Role::User, 12..12)
 355        ]
 356    );
 357    assert_eq!(
 358        message_ids_for_offsets(&text_thread, &[0, 4, 8, 12], cx),
 359        [message_1.id, message_2.id, message_3.id, message_4.id]
 360    );
 361
 362    fn message_ids_for_offsets(
 363        context: &Entity<TextThread>,
 364        offsets: &[usize],
 365        cx: &App,
 366    ) -> Vec<MessageId> {
 367        context
 368            .read(cx)
 369            .messages_for_offsets(offsets.iter().copied(), cx)
 370            .into_iter()
 371            .map(|message| message.id)
 372            .collect()
 373    }
 374}
 375
 376#[gpui::test]
 377async fn test_slash_commands(cx: &mut TestAppContext) {
 378    cx.update(init_test);
 379
 380    let fs = FakeFs::new(cx.background_executor.clone());
 381
 382    fs.insert_tree(
 383        "/test",
 384        json!({
 385            "src": {
 386                "lib.rs": "fn one() -> usize { 1 }",
 387                "main.rs": "
 388                    use crate::one;
 389                    fn main() { one(); }
 390                ".unindent(),
 391            }
 392        }),
 393    )
 394    .await;
 395
 396    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 397    slash_command_registry.register_command(FileSlashCommand, false);
 398
 399    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 400    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 401    let text_thread = cx.new(|cx| {
 402        TextThread::local(
 403            registry.clone(),
 404            None,
 405            prompt_builder.clone(),
 406            Arc::new(SlashCommandWorkingSet::default()),
 407            cx,
 408        )
 409    });
 410
 411    #[derive(Default)]
 412    struct ContextRanges {
 413        parsed_commands: HashSet<Range<language::Anchor>>,
 414        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 415        output_sections: HashSet<Range<language::Anchor>>,
 416    }
 417
 418    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 419    text_thread.update(cx, |_, cx| {
 420        cx.subscribe(&text_thread, {
 421            let context_ranges = context_ranges.clone();
 422            move |text_thread, _, event, _| {
 423                let mut context_ranges = context_ranges.borrow_mut();
 424                match event {
 425                    TextThreadEvent::InvokedSlashCommandChanged { command_id } => {
 426                        let command = text_thread.invoked_slash_command(command_id).unwrap();
 427                        context_ranges
 428                            .command_outputs
 429                            .insert(*command_id, command.range.clone());
 430                    }
 431                    TextThreadEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 432                        for range in removed {
 433                            context_ranges.parsed_commands.remove(range);
 434                        }
 435                        for command in updated {
 436                            context_ranges
 437                                .parsed_commands
 438                                .insert(command.source_range.clone());
 439                        }
 440                    }
 441                    TextThreadEvent::SlashCommandOutputSectionAdded { section } => {
 442                        context_ranges.output_sections.insert(section.range.clone());
 443                    }
 444                    _ => {}
 445                }
 446            }
 447        })
 448        .detach();
 449    });
 450
 451    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 452
 453    // Insert a slash command
 454    buffer.update(cx, |buffer, cx| {
 455        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 456    });
 457    assert_text_and_context_ranges(
 458        &buffer,
 459        &context_ranges,
 460        &"
 461        «/file src/lib.rs»"
 462            .unindent(),
 463        cx,
 464    );
 465
 466    // Edit the argument of the slash command.
 467    buffer.update(cx, |buffer, cx| {
 468        let edit_offset = buffer.text().find("lib.rs").unwrap();
 469        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 470    });
 471    assert_text_and_context_ranges(
 472        &buffer,
 473        &context_ranges,
 474        &"
 475        «/file src/main.rs»"
 476            .unindent(),
 477        cx,
 478    );
 479
 480    // Edit the name of the slash command, using one that doesn't exist.
 481    buffer.update(cx, |buffer, cx| {
 482        let edit_offset = buffer.text().find("/file").unwrap();
 483        buffer.edit(
 484            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 485            None,
 486            cx,
 487        );
 488    });
 489    assert_text_and_context_ranges(
 490        &buffer,
 491        &context_ranges,
 492        &"
 493        /unknown src/main.rs"
 494            .unindent(),
 495        cx,
 496    );
 497
 498    // Undoing the insertion of an non-existent slash command resorts the previous one.
 499    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 500    assert_text_and_context_ranges(
 501        &buffer,
 502        &context_ranges,
 503        &"
 504        «/file src/main.rs»"
 505            .unindent(),
 506        cx,
 507    );
 508
 509    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 510    text_thread.update(cx, |text_thread, cx| {
 511        let command_source_range = text_thread.parsed_slash_commands[0].source_range.clone();
 512        text_thread.insert_command_output(
 513            command_source_range,
 514            "file",
 515            Task::ready(Ok(command_output_rx.boxed())),
 516            true,
 517            cx,
 518        );
 519    });
 520    assert_text_and_context_ranges(
 521        &buffer,
 522        &context_ranges,
 523        &"
 524        ⟦«/file src/main.rs»
 525        …⟧
 526        "
 527        .unindent(),
 528        cx,
 529    );
 530
 531    command_output_tx
 532        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 533            icon: IconName::Ai,
 534            label: "src/main.rs".into(),
 535            metadata: None,
 536        }))
 537        .unwrap();
 538    command_output_tx
 539        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 540        .unwrap();
 541    cx.run_until_parked();
 542    assert_text_and_context_ranges(
 543        &buffer,
 544        &context_ranges,
 545        &"
 546        ⟦«/file src/main.rs»
 547        src/main.rs…⟧
 548        "
 549        .unindent(),
 550        cx,
 551    );
 552
 553    command_output_tx
 554        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 555        .unwrap();
 556    cx.run_until_parked();
 557    assert_text_and_context_ranges(
 558        &buffer,
 559        &context_ranges,
 560        &"
 561        ⟦«/file src/main.rs»
 562        src/main.rs
 563        fn main() {}…⟧
 564        "
 565        .unindent(),
 566        cx,
 567    );
 568
 569    command_output_tx
 570        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 571        .unwrap();
 572    cx.run_until_parked();
 573    assert_text_and_context_ranges(
 574        &buffer,
 575        &context_ranges,
 576        &"
 577        ⟦«/file src/main.rs»
 578        ⟪src/main.rs
 579        fn main() {}⟫…⟧
 580        "
 581        .unindent(),
 582        cx,
 583    );
 584
 585    drop(command_output_tx);
 586    cx.run_until_parked();
 587    assert_text_and_context_ranges(
 588        &buffer,
 589        &context_ranges,
 590        &"
 591        ⟦⟪src/main.rs
 592        fn main() {}⟫⟧
 593        "
 594        .unindent(),
 595        cx,
 596    );
 597
 598    #[track_caller]
 599    fn assert_text_and_context_ranges(
 600        buffer: &Entity<Buffer>,
 601        ranges: &RefCell<ContextRanges>,
 602        expected_marked_text: &str,
 603        cx: &mut TestAppContext,
 604    ) {
 605        let mut actual_marked_text = String::new();
 606        buffer.update(cx, |buffer, _| {
 607            struct Endpoint {
 608                offset: usize,
 609                marker: char,
 610            }
 611
 612            let ranges = ranges.borrow();
 613            let mut endpoints = Vec::new();
 614            for range in ranges.command_outputs.values() {
 615                endpoints.push(Endpoint {
 616                    offset: range.start.to_offset(buffer),
 617                    marker: '',
 618                });
 619            }
 620            for range in ranges.parsed_commands.iter() {
 621                endpoints.push(Endpoint {
 622                    offset: range.start.to_offset(buffer),
 623                    marker: '«',
 624                });
 625            }
 626            for range in ranges.output_sections.iter() {
 627                endpoints.push(Endpoint {
 628                    offset: range.start.to_offset(buffer),
 629                    marker: '',
 630                });
 631            }
 632
 633            for range in ranges.output_sections.iter() {
 634                endpoints.push(Endpoint {
 635                    offset: range.end.to_offset(buffer),
 636                    marker: '',
 637                });
 638            }
 639            for range in ranges.parsed_commands.iter() {
 640                endpoints.push(Endpoint {
 641                    offset: range.end.to_offset(buffer),
 642                    marker: '»',
 643                });
 644            }
 645            for range in ranges.command_outputs.values() {
 646                endpoints.push(Endpoint {
 647                    offset: range.end.to_offset(buffer),
 648                    marker: '',
 649                });
 650            }
 651
 652            endpoints.sort_by_key(|endpoint| endpoint.offset);
 653            let mut offset = 0;
 654            for endpoint in endpoints {
 655                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 656                actual_marked_text.push(endpoint.marker);
 657                offset = endpoint.offset;
 658            }
 659            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 660        });
 661
 662        assert_eq!(actual_marked_text, expected_marked_text);
 663    }
 664}
 665
 666#[gpui::test]
 667async fn test_serialization(cx: &mut TestAppContext) {
 668    cx.update(init_test);
 669
 670    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 671    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 672    let text_thread = cx.new(|cx| {
 673        TextThread::local(
 674            registry.clone(),
 675            None,
 676            prompt_builder.clone(),
 677            Arc::new(SlashCommandWorkingSet::default()),
 678            cx,
 679        )
 680    });
 681    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 682    let message_0 = text_thread.read_with(cx, |text_thread, _| text_thread.message_anchors[0].id);
 683    let message_1 = text_thread.update(cx, |text_thread, cx| {
 684        text_thread
 685            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 686            .unwrap()
 687    });
 688    let message_2 = text_thread.update(cx, |text_thread, cx| {
 689        text_thread
 690            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 691            .unwrap()
 692    });
 693    buffer.update(cx, |buffer, cx| {
 694        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 695        buffer.finalize_last_transaction();
 696    });
 697    let _message_3 = text_thread.update(cx, |text_thread, cx| {
 698        text_thread
 699            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 700            .unwrap()
 701    });
 702    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 703    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 704    assert_eq!(
 705        cx.read(|cx| messages(&text_thread, cx)),
 706        [
 707            (message_0, Role::User, 0..2),
 708            (message_1.id, Role::Assistant, 2..6),
 709            (message_2.id, Role::System, 6..6),
 710        ]
 711    );
 712
 713    let serialized_context = text_thread.read_with(cx, |text_thread, cx| text_thread.serialize(cx));
 714    let deserialized_context = cx.new(|cx| {
 715        TextThread::deserialize(
 716            serialized_context,
 717            Path::new("").into(),
 718            registry.clone(),
 719            prompt_builder.clone(),
 720            Arc::new(SlashCommandWorkingSet::default()),
 721            None,
 722            cx,
 723        )
 724    });
 725    let deserialized_buffer =
 726        deserialized_context.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 727    assert_eq!(
 728        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 729        "a\nb\nc\n"
 730    );
 731    assert_eq!(
 732        cx.read(|cx| messages(&deserialized_context, cx)),
 733        [
 734            (message_0, Role::User, 0..2),
 735            (message_1.id, Role::Assistant, 2..6),
 736            (message_2.id, Role::System, 6..6),
 737        ]
 738    );
 739}
 740
 741#[gpui::test(iterations = 25)]
 742async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
 743    cx.update(init_test);
 744
 745    let min_peers = env::var("MIN_PEERS")
 746        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
 747        .unwrap_or(2);
 748    let max_peers = env::var("MAX_PEERS")
 749        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 750        .unwrap_or(5);
 751    let operations = env::var("OPERATIONS")
 752        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 753        .unwrap_or(50);
 754
 755    let slash_commands = cx.update(SlashCommandRegistry::default_global);
 756    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
 757    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
 758    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
 759
 760    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
 761    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
 762    let mut text_threads = Vec::new();
 763
 764    let num_peers = rng.random_range(min_peers..=max_peers);
 765    let context_id = TextThreadId::new();
 766    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 767    for i in 0..num_peers {
 768        let context = cx.new(|cx| {
 769            TextThread::new(
 770                context_id.clone(),
 771                ReplicaId::new(i as u16),
 772                language::Capability::ReadWrite,
 773                registry.clone(),
 774                prompt_builder.clone(),
 775                Arc::new(SlashCommandWorkingSet::default()),
 776                None,
 777                cx,
 778            )
 779        });
 780
 781        cx.update(|cx| {
 782            cx.subscribe(&context, {
 783                let network = network.clone();
 784                move |_, event, _| {
 785                    if let TextThreadEvent::Operation(op) = event {
 786                        network
 787                            .lock()
 788                            .broadcast(ReplicaId::new(i as u16), vec![op.to_proto()]);
 789                    }
 790                }
 791            })
 792            .detach();
 793        });
 794
 795        text_threads.push(context);
 796        network.lock().add_peer(ReplicaId::new(i as u16));
 797    }
 798
 799    let mut mutation_count = operations;
 800
 801    while mutation_count > 0
 802        || !network.lock().is_idle()
 803        || network.lock().contains_disconnected_peers()
 804    {
 805        let context_index = rng.random_range(0..text_threads.len());
 806        let text_thread = &text_threads[context_index];
 807
 808        match rng.random_range(0..100) {
 809            0..=29 if mutation_count > 0 => {
 810                log::info!("Context {}: edit buffer", context_index);
 811                text_thread.update(cx, |text_thread, cx| {
 812                    text_thread
 813                        .buffer()
 814                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
 815                });
 816                mutation_count -= 1;
 817            }
 818            30..=44 if mutation_count > 0 => {
 819                text_thread.update(cx, |text_thread, cx| {
 820                    let range = text_thread.buffer().read(cx).random_byte_range(0, &mut rng);
 821                    log::info!("Context {}: split message at {:?}", context_index, range);
 822                    text_thread.split_message(range, cx);
 823                });
 824                mutation_count -= 1;
 825            }
 826            45..=59 if mutation_count > 0 => {
 827                text_thread.update(cx, |text_thread, cx| {
 828                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 829                        let role = *[Role::User, Role::Assistant, Role::System]
 830                            .choose(&mut rng)
 831                            .unwrap();
 832                        log::info!(
 833                            "Context {}: insert message after {:?} with {:?}",
 834                            context_index,
 835                            message.id,
 836                            role
 837                        );
 838                        text_thread.insert_message_after(message.id, role, MessageStatus::Done, cx);
 839                    }
 840                });
 841                mutation_count -= 1;
 842            }
 843            60..=74 if mutation_count > 0 => {
 844                text_thread.update(cx, |text_thread, cx| {
 845                    let command_text = "/".to_string()
 846                        + slash_commands
 847                            .command_names()
 848                            .choose(&mut rng)
 849                            .unwrap()
 850                            .clone()
 851                            .as_ref();
 852
 853                    let command_range = text_thread.buffer().update(cx, |buffer, cx| {
 854                        let offset = buffer.random_byte_range(0, &mut rng).start;
 855                        buffer.edit(
 856                            [(offset..offset, format!("\n{}\n", command_text))],
 857                            None,
 858                            cx,
 859                        );
 860                        offset + 1..offset + 1 + command_text.len()
 861                    });
 862
 863                    let output_text = RandomCharIter::new(&mut rng)
 864                        .filter(|c| *c != '\r')
 865                        .take(10)
 866                        .collect::<String>();
 867
 868                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
 869                        role: Role::User,
 870                        merge_same_roles: true,
 871                    })];
 872
 873                    let num_sections = rng.random_range(0..=3);
 874                    let mut section_start = 0;
 875                    for _ in 0..num_sections {
 876                        let section_end = output_text.floor_char_boundary(
 877                            rng.random_range(section_start..=output_text.len()),
 878                        );
 879                        events.push(Ok(SlashCommandEvent::StartSection {
 880                            icon: IconName::Ai,
 881                            label: "section".into(),
 882                            metadata: None,
 883                        }));
 884                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 885                            text: output_text[section_start..section_end].to_string(),
 886                            run_commands_in_text: false,
 887                        })));
 888                        events.push(Ok(SlashCommandEvent::EndSection));
 889                        section_start = section_end;
 890                    }
 891
 892                    if section_start < output_text.len() {
 893                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 894                            text: output_text[section_start..].to_string(),
 895                            run_commands_in_text: false,
 896                        })));
 897                    }
 898
 899                    log::info!(
 900                        "Context {}: insert slash command output at {:?} with {:?} events",
 901                        context_index,
 902                        command_range,
 903                        events.len()
 904                    );
 905
 906                    let command_range = text_thread
 907                        .buffer()
 908                        .read(cx)
 909                        .anchor_after(command_range.start)
 910                        ..text_thread
 911                            .buffer()
 912                            .read(cx)
 913                            .anchor_after(command_range.end);
 914                    text_thread.insert_command_output(
 915                        command_range,
 916                        "/command",
 917                        Task::ready(Ok(stream::iter(events).boxed())),
 918                        true,
 919                        cx,
 920                    );
 921                });
 922                cx.run_until_parked();
 923                mutation_count -= 1;
 924            }
 925            75..=84 if mutation_count > 0 => {
 926                text_thread.update(cx, |text_thread, cx| {
 927                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 928                        let new_status = match rng.random_range(0..3) {
 929                            0 => MessageStatus::Done,
 930                            1 => MessageStatus::Pending,
 931                            _ => MessageStatus::Error(SharedString::from("Random error")),
 932                        };
 933                        log::info!(
 934                            "Context {}: update message {:?} status to {:?}",
 935                            context_index,
 936                            message.id,
 937                            new_status
 938                        );
 939                        text_thread.update_metadata(message.id, cx, |metadata| {
 940                            metadata.status = new_status;
 941                        });
 942                    }
 943                });
 944                mutation_count -= 1;
 945            }
 946            _ => {
 947                let replica_id = ReplicaId::new(context_index as u16);
 948                if network.lock().is_disconnected(replica_id) {
 949                    network.lock().reconnect_peer(replica_id, ReplicaId::new(0));
 950
 951                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
 952                        let host_context = &text_threads[0].read(cx);
 953                        let guest_context = text_thread.read(cx);
 954                        (
 955                            guest_context.serialize_ops(&host_context.version(cx), cx),
 956                            host_context.serialize_ops(&guest_context.version(cx), cx),
 957                        )
 958                    });
 959                    let ops_to_send = ops_to_send.await;
 960                    let ops_to_receive = ops_to_receive
 961                        .await
 962                        .into_iter()
 963                        .map(TextThreadOperation::from_proto)
 964                        .collect::<Result<Vec<_>>>()
 965                        .unwrap();
 966                    log::info!(
 967                        "Context {}: reconnecting. Sent {} operations, received {} operations",
 968                        context_index,
 969                        ops_to_send.len(),
 970                        ops_to_receive.len()
 971                    );
 972
 973                    network.lock().broadcast(replica_id, ops_to_send);
 974                    text_thread.update(cx, |text_thread, cx| {
 975                        text_thread.apply_ops(ops_to_receive, cx)
 976                    });
 977                } else if rng.random_bool(0.1) && replica_id != ReplicaId::new(0) {
 978                    log::info!("Context {}: disconnecting", context_index);
 979                    network.lock().disconnect_peer(replica_id);
 980                } else if network.lock().has_unreceived(replica_id) {
 981                    log::info!("Context {}: applying operations", context_index);
 982                    let ops = network.lock().receive(replica_id);
 983                    let ops = ops
 984                        .into_iter()
 985                        .map(TextThreadOperation::from_proto)
 986                        .collect::<Result<Vec<_>>>()
 987                        .unwrap();
 988                    text_thread.update(cx, |text_thread, cx| text_thread.apply_ops(ops, cx));
 989                }
 990            }
 991        }
 992    }
 993
 994    cx.read(|cx| {
 995        let first_context = text_threads[0].read(cx);
 996        for text_thread in &text_threads[1..] {
 997            let text_thread = text_thread.read(cx);
 998            assert!(text_thread.pending_ops.is_empty(), "pending ops: {:?}", text_thread.pending_ops);
 999            assert_eq!(
1000                text_thread.buffer().read(cx).text(),
1001                first_context.buffer().read(cx).text(),
1002                "Context {:?} text != Context 0 text",
1003                text_thread.buffer().read(cx).replica_id()
1004            );
1005            assert_eq!(
1006                text_thread.message_anchors,
1007                first_context.message_anchors,
1008                "Context {:?} messages != Context 0 messages",
1009                text_thread.buffer().read(cx).replica_id()
1010            );
1011            assert_eq!(
1012                text_thread.messages_metadata,
1013                first_context.messages_metadata,
1014                "Context {:?} message metadata != Context 0 message metadata",
1015                text_thread.buffer().read(cx).replica_id()
1016            );
1017            assert_eq!(
1018                text_thread.slash_command_output_sections,
1019                first_context.slash_command_output_sections,
1020                "Context {:?} slash command output sections != Context 0 slash command output sections",
1021                text_thread.buffer().read(cx).replica_id()
1022            );
1023        }
1024    });
1025}
1026
1027#[gpui::test]
1028fn test_mark_cache_anchors(cx: &mut App) {
1029    init_test(cx);
1030
1031    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1032    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1033    let text_thread = cx.new(|cx| {
1034        TextThread::local(
1035            registry,
1036            None,
1037            prompt_builder.clone(),
1038            Arc::new(SlashCommandWorkingSet::default()),
1039            cx,
1040        )
1041    });
1042    let buffer = text_thread.read(cx).buffer().clone();
1043
1044    // Create a test cache configuration
1045    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1046        max_cache_anchors: 3,
1047        should_speculate: true,
1048        min_total_token: 10,
1049    });
1050
1051    let message_1 = text_thread.read(cx).message_anchors[0].clone();
1052
1053    text_thread.update(cx, |text_thread, cx| {
1054        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1055    });
1056
1057    assert_eq!(
1058        messages_cache(&text_thread, cx)
1059            .iter()
1060            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1061            .count(),
1062        0,
1063        "Empty messages should not have any cache anchors."
1064    );
1065
1066    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1067    let message_2 = text_thread
1068        .update(cx, |text_thread, cx| {
1069            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1070        })
1071        .unwrap();
1072
1073    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1074    let message_3 = text_thread
1075        .update(cx, |text_thread, cx| {
1076            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1077        })
1078        .unwrap();
1079    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1080
1081    text_thread.update(cx, |text_thread, cx| {
1082        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1083    });
1084    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1085    assert_eq!(
1086        messages_cache(&text_thread, cx)
1087            .iter()
1088            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1089            .count(),
1090        0,
1091        "Messages should not be marked for cache before going over the token minimum."
1092    );
1093    text_thread.update(cx, |text_thread, _| {
1094        text_thread.token_count = Some(20);
1095    });
1096
1097    text_thread.update(cx, |text_thread, cx| {
1098        text_thread.mark_cache_anchors(cache_configuration, true, cx)
1099    });
1100    assert_eq!(
1101        messages_cache(&text_thread, cx)
1102            .iter()
1103            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1104            .collect::<Vec<bool>>(),
1105        vec![true, true, false],
1106        "Last message should not be an anchor on speculative request."
1107    );
1108
1109    text_thread
1110        .update(cx, |text_thread, cx| {
1111            text_thread.insert_message_after(
1112                message_3.id,
1113                Role::Assistant,
1114                MessageStatus::Pending,
1115                cx,
1116            )
1117        })
1118        .unwrap();
1119
1120    text_thread.update(cx, |text_thread, cx| {
1121        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1122    });
1123    assert_eq!(
1124        messages_cache(&text_thread, cx)
1125            .iter()
1126            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1127            .collect::<Vec<bool>>(),
1128        vec![false, true, true, false],
1129        "Most recent message should also be cached if not a speculative request."
1130    );
1131    text_thread.update(cx, |text_thread, cx| {
1132        text_thread.update_cache_status_for_completion(cx)
1133    });
1134    assert_eq!(
1135        messages_cache(&text_thread, cx)
1136            .iter()
1137            .map(|(_, cache)| cache
1138                .as_ref()
1139                .map_or(None, |cache| Some(cache.status.clone())))
1140            .collect::<Vec<Option<CacheStatus>>>(),
1141        vec![
1142            Some(CacheStatus::Cached),
1143            Some(CacheStatus::Cached),
1144            Some(CacheStatus::Cached),
1145            None
1146        ],
1147        "All user messages prior to anchor should be marked as cached."
1148    );
1149
1150    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1151    text_thread.update(cx, |text_thread, cx| {
1152        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1153    });
1154    assert_eq!(
1155        messages_cache(&text_thread, cx)
1156            .iter()
1157            .map(|(_, cache)| cache
1158                .as_ref()
1159                .map_or(None, |cache| Some(cache.status.clone())))
1160            .collect::<Vec<Option<CacheStatus>>>(),
1161        vec![
1162            Some(CacheStatus::Cached),
1163            Some(CacheStatus::Cached),
1164            Some(CacheStatus::Pending),
1165            None
1166        ],
1167        "Modifying a message should invalidate it's cache but leave previous messages."
1168    );
1169    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1170    text_thread.update(cx, |text_thread, cx| {
1171        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1172    });
1173    assert_eq!(
1174        messages_cache(&text_thread, cx)
1175            .iter()
1176            .map(|(_, cache)| cache
1177                .as_ref()
1178                .map_or(None, |cache| Some(cache.status.clone())))
1179            .collect::<Vec<Option<CacheStatus>>>(),
1180        vec![
1181            Some(CacheStatus::Pending),
1182            Some(CacheStatus::Pending),
1183            Some(CacheStatus::Pending),
1184            None
1185        ],
1186        "Modifying a message should invalidate all future messages."
1187    );
1188}
1189
1190#[gpui::test]
1191async fn test_summarization(cx: &mut TestAppContext) {
1192    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1193
1194    // Initial state should be pending
1195    text_thread.read_with(cx, |text_thread, _| {
1196        assert!(matches!(text_thread.summary(), TextThreadSummary::Pending));
1197        assert_eq!(
1198            text_thread.summary().or_default(),
1199            TextThreadSummary::DEFAULT
1200        );
1201    });
1202
1203    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1204        text_thread.message_anchors[0].clone()
1205    });
1206    text_thread.update(cx, |context, cx| {
1207        context
1208            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1209            .unwrap();
1210    });
1211
1212    // Send a message
1213    text_thread.update(cx, |text_thread, cx| {
1214        text_thread.assist(cx);
1215    });
1216
1217    simulate_successful_response(&fake_model, cx);
1218
1219    // Should start generating summary when there are >= 2 messages
1220    text_thread.read_with(cx, |text_thread, _| {
1221        assert!(!text_thread.summary().content().unwrap().done);
1222    });
1223
1224    cx.run_until_parked();
1225    fake_model.send_last_completion_stream_text_chunk("Brief");
1226    fake_model.send_last_completion_stream_text_chunk(" Introduction");
1227    fake_model.end_last_completion_stream();
1228    cx.run_until_parked();
1229
1230    // Summary should be set
1231    text_thread.read_with(cx, |text_thread, _| {
1232        assert_eq!(text_thread.summary().or_default(), "Brief Introduction");
1233    });
1234
1235    // We should be able to manually set a summary
1236    text_thread.update(cx, |text_thread, cx| {
1237        text_thread.set_custom_summary("Brief Intro".into(), cx);
1238    });
1239
1240    text_thread.read_with(cx, |text_thread, _| {
1241        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1242    });
1243}
1244
1245#[gpui::test]
1246async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1247    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1248
1249    test_summarize_error(&fake_model, &text_thread, cx);
1250
1251    // Now we should be able to set a summary
1252    text_thread.update(cx, |text_thread, cx| {
1253        text_thread.set_custom_summary("Brief Intro".into(), cx);
1254    });
1255
1256    text_thread.read_with(cx, |text_thread, _| {
1257        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1258    });
1259}
1260
1261#[gpui::test]
1262async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1263    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1264
1265    test_summarize_error(&fake_model, &text_thread, cx);
1266
1267    // Sending another message should not trigger another summarize request
1268    text_thread.update(cx, |text_thread, cx| {
1269        text_thread.assist(cx);
1270    });
1271
1272    simulate_successful_response(&fake_model, cx);
1273
1274    text_thread.read_with(cx, |text_thread, _| {
1275        // State is still Error, not Generating
1276        assert!(matches!(text_thread.summary(), TextThreadSummary::Error));
1277    });
1278
1279    // But the summarize request can be invoked manually
1280    text_thread.update(cx, |text_thread, cx| {
1281        text_thread.summarize(true, cx);
1282    });
1283
1284    text_thread.read_with(cx, |text_thread, _| {
1285        assert!(!text_thread.summary().content().unwrap().done);
1286    });
1287
1288    cx.run_until_parked();
1289    fake_model.send_last_completion_stream_text_chunk("A successful summary");
1290    fake_model.end_last_completion_stream();
1291    cx.run_until_parked();
1292
1293    text_thread.read_with(cx, |text_thread, _| {
1294        assert_eq!(text_thread.summary().or_default(), "A successful summary");
1295    });
1296}
1297
1298fn test_summarize_error(
1299    model: &Arc<FakeLanguageModel>,
1300    text_thread: &Entity<TextThread>,
1301    cx: &mut TestAppContext,
1302) {
1303    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1304        text_thread.message_anchors[0].clone()
1305    });
1306    text_thread.update(cx, |text_thread, cx| {
1307        text_thread
1308            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1309            .unwrap();
1310    });
1311
1312    // Send a message
1313    text_thread.update(cx, |text_thread, cx| {
1314        text_thread.assist(cx);
1315    });
1316
1317    simulate_successful_response(model, cx);
1318
1319    text_thread.read_with(cx, |text_thread, _| {
1320        assert!(!text_thread.summary().content().unwrap().done);
1321    });
1322
1323    // Simulate summary request ending
1324    cx.run_until_parked();
1325    model.end_last_completion_stream();
1326    cx.run_until_parked();
1327
1328    // State is set to Error and default message
1329    text_thread.read_with(cx, |text_thread, _| {
1330        assert_eq!(*text_thread.summary(), TextThreadSummary::Error);
1331        assert_eq!(
1332            text_thread.summary().or_default(),
1333            TextThreadSummary::DEFAULT
1334        );
1335    });
1336}
1337
1338fn setup_context_editor_with_fake_model(
1339    cx: &mut TestAppContext,
1340) -> (Entity<TextThread>, Arc<FakeLanguageModel>) {
1341    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1342
1343    let fake_provider = Arc::new(FakeLanguageModelProvider::default());
1344    let fake_model = Arc::new(fake_provider.test_model());
1345
1346    cx.update(|cx| {
1347        init_test(cx);
1348        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1349            let configured_model = ConfiguredModel {
1350                provider: fake_provider.clone(),
1351                model: fake_model.clone(),
1352            };
1353            registry.set_default_model(Some(configured_model.clone()), cx);
1354            registry.set_thread_summary_model(Some(configured_model), cx);
1355        })
1356    });
1357
1358    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1359    let context = cx.new(|cx| {
1360        TextThread::local(
1361            registry,
1362            None,
1363            prompt_builder.clone(),
1364            Arc::new(SlashCommandWorkingSet::default()),
1365            cx,
1366        )
1367    });
1368
1369    (context, fake_model)
1370}
1371
1372fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1373    cx.run_until_parked();
1374    fake_model.send_last_completion_stream_text_chunk("Assistant response");
1375    fake_model.end_last_completion_stream();
1376    cx.run_until_parked();
1377}
1378
1379fn messages(context: &Entity<TextThread>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1380    context
1381        .read(cx)
1382        .messages(cx)
1383        .map(|message| (message.id, message.role, message.offset_range))
1384        .collect()
1385}
1386
1387fn messages_cache(
1388    context: &Entity<TextThread>,
1389    cx: &App,
1390) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1391    context
1392        .read(cx)
1393        .messages(cx)
1394        .map(|message| (message.id, message.cache))
1395        .collect()
1396}
1397
/// Shared test setup for this module.
///
/// It creates a test `SettingsStore`, initializes the prompt store, calls
/// `LanguageModelRegistry::test`, and finally installs the settings store as a global.
/// `setup_context_editor_with_fake_model` calls `LanguageModelRegistry::global` right after
/// this, so the registry set up here is expected to be available as a global.
fn init_test(cx: &mut App) {
    // The settings store is built first but registered as a global only after the other
    // initializers have run. NOTE(review): confirm this ordering is intentional before changing it.
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
}
1404
1405#[derive(Clone)]
1406struct FakeSlashCommand(String);
1407
1408impl SlashCommand for FakeSlashCommand {
1409    fn name(&self) -> String {
1410        self.0.clone()
1411    }
1412
1413    fn description(&self) -> String {
1414        format!("Fake slash command: {}", self.0)
1415    }
1416
1417    fn menu_text(&self) -> String {
1418        format!("Run fake command: {}", self.0)
1419    }
1420
1421    fn complete_argument(
1422        self: Arc<Self>,
1423        _arguments: &[String],
1424        _cancel: Arc<AtomicBool>,
1425        _workspace: Option<WeakEntity<Workspace>>,
1426        _window: &mut Window,
1427        _cx: &mut App,
1428    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1429        Task::ready(Ok(vec![]))
1430    }
1431
1432    fn requires_argument(&self) -> bool {
1433        false
1434    }
1435
1436    fn run(
1437        self: Arc<Self>,
1438        _arguments: &[String],
1439        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1440        _context_buffer: BufferSnapshot,
1441        _workspace: WeakEntity<Workspace>,
1442        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1443        _window: &mut Window,
1444        _cx: &mut App,
1445    ) -> Task<SlashCommandResult> {
1446        Task::ready(Ok(SlashCommandOutput {
1447            text: format!("Executed fake command: {}", self.0),
1448            sections: vec![],
1449            run_commands_in_text: false,
1450        }
1451        .into_event_stream()))
1452    }
1453}