assistant_text_thread_tests.rs

   1use crate::{
   2    CacheStatus, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, TextThread,
   3    TextThreadEvent, TextThreadId, TextThreadOperation, TextThreadSummary,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{
  20    ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
  21    fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
  22};
  23use parking_lot::Mutex;
  24use pretty_assertions::assert_eq;
  25use prompt_store::PromptBuilder;
  26use rand::prelude::*;
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{
  30    cell::RefCell,
  31    env,
  32    ops::Range,
  33    path::Path,
  34    rc::Rc,
  35    sync::{Arc, atomic::AtomicBool},
  36};
  37use text::{ReplicaId, ToOffset, network::Network};
  38use ui::{IconName, Window};
  39use unindent::Unindent;
  40use util::RandomCharIter;
  41use workspace::Workspace;
  42
  43#[gpui::test]
  44fn test_inserting_and_removing_messages(cx: &mut App) {
  45    init_test(cx);
  46
  47    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  48    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  49    let text_thread = cx.new(|cx| {
  50        TextThread::local(
  51            registry,
  52            prompt_builder.clone(),
  53            Arc::new(SlashCommandWorkingSet::default()),
  54            cx,
  55        )
  56    });
  57    let buffer = text_thread.read(cx).buffer().clone();
  58
  59    let message_1 = text_thread.read(cx).message_anchors[0].clone();
  60    assert_eq!(
  61        messages(&text_thread, cx),
  62        vec![(message_1.id, Role::User, 0..0)]
  63    );
  64
  65    let message_2 = text_thread.update(cx, |context, cx| {
  66        context
  67            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  68            .unwrap()
  69    });
  70    assert_eq!(
  71        messages(&text_thread, cx),
  72        vec![
  73            (message_1.id, Role::User, 0..1),
  74            (message_2.id, Role::Assistant, 1..1)
  75        ]
  76    );
  77
  78    buffer.update(cx, |buffer, cx| {
  79        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  80    });
  81    assert_eq!(
  82        messages(&text_thread, cx),
  83        vec![
  84            (message_1.id, Role::User, 0..2),
  85            (message_2.id, Role::Assistant, 2..3)
  86        ]
  87    );
  88
  89    let message_3 = text_thread.update(cx, |context, cx| {
  90        context
  91            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  92            .unwrap()
  93    });
  94    assert_eq!(
  95        messages(&text_thread, cx),
  96        vec![
  97            (message_1.id, Role::User, 0..2),
  98            (message_2.id, Role::Assistant, 2..4),
  99            (message_3.id, Role::User, 4..4)
 100        ]
 101    );
 102
 103    let message_4 = text_thread.update(cx, |context, cx| {
 104        context
 105            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 106            .unwrap()
 107    });
 108    assert_eq!(
 109        messages(&text_thread, cx),
 110        vec![
 111            (message_1.id, Role::User, 0..2),
 112            (message_2.id, Role::Assistant, 2..4),
 113            (message_4.id, Role::User, 4..5),
 114            (message_3.id, Role::User, 5..5),
 115        ]
 116    );
 117
 118    buffer.update(cx, |buffer, cx| {
 119        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 120    });
 121    assert_eq!(
 122        messages(&text_thread, cx),
 123        vec![
 124            (message_1.id, Role::User, 0..2),
 125            (message_2.id, Role::Assistant, 2..4),
 126            (message_4.id, Role::User, 4..6),
 127            (message_3.id, Role::User, 6..7),
 128        ]
 129    );
 130
 131    // Deleting across message boundaries merges the messages.
 132    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 133    assert_eq!(
 134        messages(&text_thread, cx),
 135        vec![
 136            (message_1.id, Role::User, 0..3),
 137            (message_3.id, Role::User, 3..4),
 138        ]
 139    );
 140
 141    // Undoing the deletion should also undo the merge.
 142    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 143    assert_eq!(
 144        messages(&text_thread, cx),
 145        vec![
 146            (message_1.id, Role::User, 0..2),
 147            (message_2.id, Role::Assistant, 2..4),
 148            (message_4.id, Role::User, 4..6),
 149            (message_3.id, Role::User, 6..7),
 150        ]
 151    );
 152
 153    // Redoing the deletion should also redo the merge.
 154    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 155    assert_eq!(
 156        messages(&text_thread, cx),
 157        vec![
 158            (message_1.id, Role::User, 0..3),
 159            (message_3.id, Role::User, 3..4),
 160        ]
 161    );
 162
 163    // Ensure we can still insert after a merged message.
 164    let message_5 = text_thread.update(cx, |context, cx| {
 165        context
 166            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 167            .unwrap()
 168    });
 169    assert_eq!(
 170        messages(&text_thread, cx),
 171        vec![
 172            (message_1.id, Role::User, 0..3),
 173            (message_5.id, Role::System, 3..4),
 174            (message_3.id, Role::User, 4..5)
 175        ]
 176    );
 177}
 178
 179#[gpui::test]
 180fn test_message_splitting(cx: &mut App) {
 181    init_test(cx);
 182
 183    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 184
 185    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 186    let text_thread = cx.new(|cx| {
 187        TextThread::local(
 188            registry.clone(),
 189            prompt_builder.clone(),
 190            Arc::new(SlashCommandWorkingSet::default()),
 191            cx,
 192        )
 193    });
 194    let buffer = text_thread.read(cx).buffer().clone();
 195
 196    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 197    assert_eq!(
 198        messages(&text_thread, cx),
 199        vec![(message_1.id, Role::User, 0..0)]
 200    );
 201
 202    buffer.update(cx, |buffer, cx| {
 203        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 204    });
 205
 206    let (_, message_2) =
 207        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 208    let message_2 = message_2.unwrap();
 209
 210    // We recycle newlines in the middle of a split message
 211    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 212    assert_eq!(
 213        messages(&text_thread, cx),
 214        vec![
 215            (message_1.id, Role::User, 0..4),
 216            (message_2.id, Role::User, 4..16),
 217        ]
 218    );
 219
 220    let (_, message_3) =
 221        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 222    let message_3 = message_3.unwrap();
 223
 224    // We don't recycle newlines at the end of a split message
 225    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 226    assert_eq!(
 227        messages(&text_thread, cx),
 228        vec![
 229            (message_1.id, Role::User, 0..4),
 230            (message_3.id, Role::User, 4..5),
 231            (message_2.id, Role::User, 5..17),
 232        ]
 233    );
 234
 235    let (_, message_4) =
 236        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 237    let message_4 = message_4.unwrap();
 238    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 239    assert_eq!(
 240        messages(&text_thread, cx),
 241        vec![
 242            (message_1.id, Role::User, 0..4),
 243            (message_3.id, Role::User, 4..5),
 244            (message_2.id, Role::User, 5..9),
 245            (message_4.id, Role::User, 9..17),
 246        ]
 247    );
 248
 249    let (_, message_5) =
 250        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 251    let message_5 = message_5.unwrap();
 252    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 253    assert_eq!(
 254        messages(&text_thread, cx),
 255        vec![
 256            (message_1.id, Role::User, 0..4),
 257            (message_3.id, Role::User, 4..5),
 258            (message_2.id, Role::User, 5..9),
 259            (message_4.id, Role::User, 9..10),
 260            (message_5.id, Role::User, 10..18),
 261        ]
 262    );
 263
 264    let (message_6, message_7) =
 265        text_thread.update(cx, |text_thread, cx| text_thread.split_message(14..16, cx));
 266    let message_6 = message_6.unwrap();
 267    let message_7 = message_7.unwrap();
 268    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 269    assert_eq!(
 270        messages(&text_thread, cx),
 271        vec![
 272            (message_1.id, Role::User, 0..4),
 273            (message_3.id, Role::User, 4..5),
 274            (message_2.id, Role::User, 5..9),
 275            (message_4.id, Role::User, 9..10),
 276            (message_5.id, Role::User, 10..14),
 277            (message_6.id, Role::User, 14..17),
 278            (message_7.id, Role::User, 17..19),
 279        ]
 280    );
 281}
 282
 283#[gpui::test]
 284fn test_messages_for_offsets(cx: &mut App) {
 285    init_test(cx);
 286
 287    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 288    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 289    let text_thread = cx.new(|cx| {
 290        TextThread::local(
 291            registry,
 292            prompt_builder.clone(),
 293            Arc::new(SlashCommandWorkingSet::default()),
 294            cx,
 295        )
 296    });
 297    let buffer = text_thread.read(cx).buffer().clone();
 298
 299    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 300    assert_eq!(
 301        messages(&text_thread, cx),
 302        vec![(message_1.id, Role::User, 0..0)]
 303    );
 304
 305    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 306    let message_2 = text_thread
 307        .update(cx, |text_thread, cx| {
 308            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 309        })
 310        .unwrap();
 311    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 312
 313    let message_3 = text_thread
 314        .update(cx, |text_thread, cx| {
 315            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 316        })
 317        .unwrap();
 318    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 319
 320    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 321    assert_eq!(
 322        messages(&text_thread, cx),
 323        vec![
 324            (message_1.id, Role::User, 0..4),
 325            (message_2.id, Role::User, 4..8),
 326            (message_3.id, Role::User, 8..11)
 327        ]
 328    );
 329
 330    assert_eq!(
 331        message_ids_for_offsets(&text_thread, &[0, 4, 9], cx),
 332        [message_1.id, message_2.id, message_3.id]
 333    );
 334    assert_eq!(
 335        message_ids_for_offsets(&text_thread, &[0, 1, 11], cx),
 336        [message_1.id, message_3.id]
 337    );
 338
 339    let message_4 = text_thread
 340        .update(cx, |text_thread, cx| {
 341            text_thread.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 342        })
 343        .unwrap();
 344    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 345    assert_eq!(
 346        messages(&text_thread, cx),
 347        vec![
 348            (message_1.id, Role::User, 0..4),
 349            (message_2.id, Role::User, 4..8),
 350            (message_3.id, Role::User, 8..12),
 351            (message_4.id, Role::User, 12..12)
 352        ]
 353    );
 354    assert_eq!(
 355        message_ids_for_offsets(&text_thread, &[0, 4, 8, 12], cx),
 356        [message_1.id, message_2.id, message_3.id, message_4.id]
 357    );
 358
 359    fn message_ids_for_offsets(
 360        context: &Entity<TextThread>,
 361        offsets: &[usize],
 362        cx: &App,
 363    ) -> Vec<MessageId> {
 364        context
 365            .read(cx)
 366            .messages_for_offsets(offsets.iter().copied(), cx)
 367            .into_iter()
 368            .map(|message| message.id)
 369            .collect()
 370    }
 371}
 372
 373#[gpui::test]
 374async fn test_slash_commands(cx: &mut TestAppContext) {
 375    cx.update(init_test);
 376
 377    let fs = FakeFs::new(cx.background_executor.clone());
 378
 379    fs.insert_tree(
 380        "/test",
 381        json!({
 382            "src": {
 383                "lib.rs": "fn one() -> usize { 1 }",
 384                "main.rs": "
 385                    use crate::one;
 386                    fn main() { one(); }
 387                ".unindent(),
 388            }
 389        }),
 390    )
 391    .await;
 392
 393    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 394    slash_command_registry.register_command(FileSlashCommand, false);
 395
 396    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 397    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 398    let text_thread = cx.new(|cx| {
 399        TextThread::local(
 400            registry.clone(),
 401            prompt_builder.clone(),
 402            Arc::new(SlashCommandWorkingSet::default()),
 403            cx,
 404        )
 405    });
 406
 407    #[derive(Default)]
 408    struct ContextRanges {
 409        parsed_commands: HashSet<Range<language::Anchor>>,
 410        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 411        output_sections: HashSet<Range<language::Anchor>>,
 412    }
 413
 414    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 415    text_thread.update(cx, |_, cx| {
 416        cx.subscribe(&text_thread, {
 417            let context_ranges = context_ranges.clone();
 418            move |text_thread, _, event, _| {
 419                let mut context_ranges = context_ranges.borrow_mut();
 420                match event {
 421                    TextThreadEvent::InvokedSlashCommandChanged { command_id } => {
 422                        let command = text_thread.invoked_slash_command(command_id).unwrap();
 423                        context_ranges
 424                            .command_outputs
 425                            .insert(*command_id, command.range.clone());
 426                    }
 427                    TextThreadEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 428                        for range in removed {
 429                            context_ranges.parsed_commands.remove(range);
 430                        }
 431                        for command in updated {
 432                            context_ranges
 433                                .parsed_commands
 434                                .insert(command.source_range.clone());
 435                        }
 436                    }
 437                    TextThreadEvent::SlashCommandOutputSectionAdded { section } => {
 438                        context_ranges.output_sections.insert(section.range.clone());
 439                    }
 440                    _ => {}
 441                }
 442            }
 443        })
 444        .detach();
 445    });
 446
 447    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 448
 449    // Insert a slash command
 450    buffer.update(cx, |buffer, cx| {
 451        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 452    });
 453    assert_text_and_context_ranges(
 454        &buffer,
 455        &context_ranges,
 456        &"
 457        «/file src/lib.rs»"
 458            .unindent(),
 459        cx,
 460    );
 461
 462    // Edit the argument of the slash command.
 463    buffer.update(cx, |buffer, cx| {
 464        let edit_offset = buffer.text().find("lib.rs").unwrap();
 465        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 466    });
 467    assert_text_and_context_ranges(
 468        &buffer,
 469        &context_ranges,
 470        &"
 471        «/file src/main.rs»"
 472            .unindent(),
 473        cx,
 474    );
 475
 476    // Edit the name of the slash command, using one that doesn't exist.
 477    buffer.update(cx, |buffer, cx| {
 478        let edit_offset = buffer.text().find("/file").unwrap();
 479        buffer.edit(
 480            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 481            None,
 482            cx,
 483        );
 484    });
 485    assert_text_and_context_ranges(
 486        &buffer,
 487        &context_ranges,
 488        &"
 489        /unknown src/main.rs"
 490            .unindent(),
 491        cx,
 492    );
 493
 494    // Undoing the insertion of an non-existent slash command resorts the previous one.
 495    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 496    assert_text_and_context_ranges(
 497        &buffer,
 498        &context_ranges,
 499        &"
 500        «/file src/main.rs»"
 501            .unindent(),
 502        cx,
 503    );
 504
 505    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 506    text_thread.update(cx, |text_thread, cx| {
 507        let command_source_range = text_thread.parsed_slash_commands[0].source_range.clone();
 508        text_thread.insert_command_output(
 509            command_source_range,
 510            "file",
 511            Task::ready(Ok(command_output_rx.boxed())),
 512            true,
 513            cx,
 514        );
 515    });
 516    assert_text_and_context_ranges(
 517        &buffer,
 518        &context_ranges,
 519        &"
 520        ⟦«/file src/main.rs»
 521        …⟧
 522        "
 523        .unindent(),
 524        cx,
 525    );
 526
 527    command_output_tx
 528        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 529            icon: IconName::Ai,
 530            label: "src/main.rs".into(),
 531            metadata: None,
 532        }))
 533        .unwrap();
 534    command_output_tx
 535        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 536        .unwrap();
 537    cx.run_until_parked();
 538    assert_text_and_context_ranges(
 539        &buffer,
 540        &context_ranges,
 541        &"
 542        ⟦«/file src/main.rs»
 543        src/main.rs…⟧
 544        "
 545        .unindent(),
 546        cx,
 547    );
 548
 549    command_output_tx
 550        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 551        .unwrap();
 552    cx.run_until_parked();
 553    assert_text_and_context_ranges(
 554        &buffer,
 555        &context_ranges,
 556        &"
 557        ⟦«/file src/main.rs»
 558        src/main.rs
 559        fn main() {}…⟧
 560        "
 561        .unindent(),
 562        cx,
 563    );
 564
 565    command_output_tx
 566        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 567        .unwrap();
 568    cx.run_until_parked();
 569    assert_text_and_context_ranges(
 570        &buffer,
 571        &context_ranges,
 572        &"
 573        ⟦«/file src/main.rs»
 574        ⟪src/main.rs
 575        fn main() {}⟫…⟧
 576        "
 577        .unindent(),
 578        cx,
 579    );
 580
 581    drop(command_output_tx);
 582    cx.run_until_parked();
 583    assert_text_and_context_ranges(
 584        &buffer,
 585        &context_ranges,
 586        &"
 587        ⟦⟪src/main.rs
 588        fn main() {}⟫⟧
 589        "
 590        .unindent(),
 591        cx,
 592    );
 593
 594    #[track_caller]
 595    fn assert_text_and_context_ranges(
 596        buffer: &Entity<Buffer>,
 597        ranges: &RefCell<ContextRanges>,
 598        expected_marked_text: &str,
 599        cx: &mut TestAppContext,
 600    ) {
 601        let mut actual_marked_text = String::new();
 602        buffer.update(cx, |buffer, _| {
 603            struct Endpoint {
 604                offset: usize,
 605                marker: char,
 606            }
 607
 608            let ranges = ranges.borrow();
 609            let mut endpoints = Vec::new();
 610            for range in ranges.command_outputs.values() {
 611                endpoints.push(Endpoint {
 612                    offset: range.start.to_offset(buffer),
 613                    marker: '',
 614                });
 615            }
 616            for range in ranges.parsed_commands.iter() {
 617                endpoints.push(Endpoint {
 618                    offset: range.start.to_offset(buffer),
 619                    marker: '«',
 620                });
 621            }
 622            for range in ranges.output_sections.iter() {
 623                endpoints.push(Endpoint {
 624                    offset: range.start.to_offset(buffer),
 625                    marker: '',
 626                });
 627            }
 628
 629            for range in ranges.output_sections.iter() {
 630                endpoints.push(Endpoint {
 631                    offset: range.end.to_offset(buffer),
 632                    marker: '',
 633                });
 634            }
 635            for range in ranges.parsed_commands.iter() {
 636                endpoints.push(Endpoint {
 637                    offset: range.end.to_offset(buffer),
 638                    marker: '»',
 639                });
 640            }
 641            for range in ranges.command_outputs.values() {
 642                endpoints.push(Endpoint {
 643                    offset: range.end.to_offset(buffer),
 644                    marker: '',
 645                });
 646            }
 647
 648            endpoints.sort_by_key(|endpoint| endpoint.offset);
 649            let mut offset = 0;
 650            for endpoint in endpoints {
 651                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 652                actual_marked_text.push(endpoint.marker);
 653                offset = endpoint.offset;
 654            }
 655            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 656        });
 657
 658        assert_eq!(actual_marked_text, expected_marked_text);
 659    }
 660}
 661
 662#[gpui::test]
 663async fn test_serialization(cx: &mut TestAppContext) {
 664    cx.update(init_test);
 665
 666    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 667    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 668    let text_thread = cx.new(|cx| {
 669        TextThread::local(
 670            registry.clone(),
 671            prompt_builder.clone(),
 672            Arc::new(SlashCommandWorkingSet::default()),
 673            cx,
 674        )
 675    });
 676    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 677    let message_0 = text_thread.read_with(cx, |text_thread, _| text_thread.message_anchors[0].id);
 678    let message_1 = text_thread.update(cx, |text_thread, cx| {
 679        text_thread
 680            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 681            .unwrap()
 682    });
 683    let message_2 = text_thread.update(cx, |text_thread, cx| {
 684        text_thread
 685            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 686            .unwrap()
 687    });
 688    buffer.update(cx, |buffer, cx| {
 689        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 690        buffer.finalize_last_transaction();
 691    });
 692    let _message_3 = text_thread.update(cx, |text_thread, cx| {
 693        text_thread
 694            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 695            .unwrap()
 696    });
 697    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 698    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 699    assert_eq!(
 700        cx.read(|cx| messages(&text_thread, cx)),
 701        [
 702            (message_0, Role::User, 0..2),
 703            (message_1.id, Role::Assistant, 2..6),
 704            (message_2.id, Role::System, 6..6),
 705        ]
 706    );
 707
 708    let serialized_context = text_thread.read_with(cx, |text_thread, cx| text_thread.serialize(cx));
 709    let deserialized_context = cx.new(|cx| {
 710        TextThread::deserialize(
 711            serialized_context,
 712            Path::new("").into(),
 713            registry.clone(),
 714            prompt_builder.clone(),
 715            Arc::new(SlashCommandWorkingSet::default()),
 716            cx,
 717        )
 718    });
 719    let deserialized_buffer =
 720        deserialized_context.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 721    assert_eq!(
 722        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 723        "a\nb\nc\n"
 724    );
 725    assert_eq!(
 726        cx.read(|cx| messages(&deserialized_context, cx)),
 727        [
 728            (message_0, Role::User, 0..2),
 729            (message_1.id, Role::Assistant, 2..6),
 730            (message_2.id, Role::System, 6..6),
 731        ]
 732    );
 733}
 734
 735#[gpui::test(iterations = 25)]
 736async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
 737    cx.update(init_test);
 738
 739    let min_peers = env::var("MIN_PEERS")
 740        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
 741        .unwrap_or(2);
 742    let max_peers = env::var("MAX_PEERS")
 743        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 744        .unwrap_or(5);
 745    let operations = env::var("OPERATIONS")
 746        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 747        .unwrap_or(50);
 748
 749    let slash_commands = cx.update(SlashCommandRegistry::default_global);
 750    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
 751    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
 752    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
 753
 754    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
 755    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
 756    let mut text_threads = Vec::new();
 757
 758    let num_peers = rng.random_range(min_peers..=max_peers);
 759    let context_id = TextThreadId::new();
 760    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 761    for i in 0..num_peers {
 762        let context = cx.new(|cx| {
 763            TextThread::new(
 764                context_id.clone(),
 765                ReplicaId::new(i as u16),
 766                language::Capability::ReadWrite,
 767                registry.clone(),
 768                prompt_builder.clone(),
 769                Arc::new(SlashCommandWorkingSet::default()),
 770                cx,
 771            )
 772        });
 773
 774        cx.update(|cx| {
 775            cx.subscribe(&context, {
 776                let network = network.clone();
 777                move |_, event, _| {
 778                    if let TextThreadEvent::Operation(op) = event {
 779                        network
 780                            .lock()
 781                            .broadcast(ReplicaId::new(i as u16), vec![op.to_proto()]);
 782                    }
 783                }
 784            })
 785            .detach();
 786        });
 787
 788        text_threads.push(context);
 789        network.lock().add_peer(ReplicaId::new(i as u16));
 790    }
 791
 792    let mut mutation_count = operations;
 793
 794    while mutation_count > 0
 795        || !network.lock().is_idle()
 796        || network.lock().contains_disconnected_peers()
 797    {
 798        let context_index = rng.random_range(0..text_threads.len());
 799        let text_thread = &text_threads[context_index];
 800
 801        match rng.random_range(0..100) {
 802            0..=29 if mutation_count > 0 => {
 803                log::info!("Context {}: edit buffer", context_index);
 804                text_thread.update(cx, |text_thread, cx| {
 805                    text_thread
 806                        .buffer()
 807                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
 808                });
 809                mutation_count -= 1;
 810            }
 811            30..=44 if mutation_count > 0 => {
 812                text_thread.update(cx, |text_thread, cx| {
 813                    let range = text_thread.buffer().read(cx).random_byte_range(0, &mut rng);
 814                    log::info!("Context {}: split message at {:?}", context_index, range);
 815                    text_thread.split_message(range, cx);
 816                });
 817                mutation_count -= 1;
 818            }
 819            45..=59 if mutation_count > 0 => {
 820                text_thread.update(cx, |text_thread, cx| {
 821                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 822                        let role = *[Role::User, Role::Assistant, Role::System]
 823                            .choose(&mut rng)
 824                            .unwrap();
 825                        log::info!(
 826                            "Context {}: insert message after {:?} with {:?}",
 827                            context_index,
 828                            message.id,
 829                            role
 830                        );
 831                        text_thread.insert_message_after(message.id, role, MessageStatus::Done, cx);
 832                    }
 833                });
 834                mutation_count -= 1;
 835            }
 836            60..=74 if mutation_count > 0 => {
 837                text_thread.update(cx, |text_thread, cx| {
 838                    let command_text = "/".to_string()
 839                        + slash_commands
 840                            .command_names()
 841                            .choose(&mut rng)
 842                            .unwrap()
 843                            .clone()
 844                            .as_ref();
 845
 846                    let command_range = text_thread.buffer().update(cx, |buffer, cx| {
 847                        let offset = buffer.random_byte_range(0, &mut rng).start;
 848                        buffer.edit(
 849                            [(offset..offset, format!("\n{}\n", command_text))],
 850                            None,
 851                            cx,
 852                        );
 853                        offset + 1..offset + 1 + command_text.len()
 854                    });
 855
 856                    let output_text = RandomCharIter::new(&mut rng)
 857                        .filter(|c| *c != '\r')
 858                        .take(10)
 859                        .collect::<String>();
 860
 861                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
 862                        role: Role::User,
 863                        merge_same_roles: true,
 864                    })];
 865
 866                    let num_sections = rng.random_range(0..=3);
 867                    let mut section_start = 0;
 868                    for _ in 0..num_sections {
 869                        let section_end = output_text.floor_char_boundary(
 870                            rng.random_range(section_start..=output_text.len()),
 871                        );
 872                        events.push(Ok(SlashCommandEvent::StartSection {
 873                            icon: IconName::Ai,
 874                            label: "section".into(),
 875                            metadata: None,
 876                        }));
 877                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 878                            text: output_text[section_start..section_end].to_string(),
 879                            run_commands_in_text: false,
 880                        })));
 881                        events.push(Ok(SlashCommandEvent::EndSection));
 882                        section_start = section_end;
 883                    }
 884
 885                    if section_start < output_text.len() {
 886                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 887                            text: output_text[section_start..].to_string(),
 888                            run_commands_in_text: false,
 889                        })));
 890                    }
 891
 892                    log::info!(
 893                        "Context {}: insert slash command output at {:?} with {:?} events",
 894                        context_index,
 895                        command_range,
 896                        events.len()
 897                    );
 898
 899                    let command_range = text_thread
 900                        .buffer()
 901                        .read(cx)
 902                        .anchor_after(command_range.start)
 903                        ..text_thread
 904                            .buffer()
 905                            .read(cx)
 906                            .anchor_after(command_range.end);
 907                    text_thread.insert_command_output(
 908                        command_range,
 909                        "/command",
 910                        Task::ready(Ok(stream::iter(events).boxed())),
 911                        true,
 912                        cx,
 913                    );
 914                });
 915                cx.run_until_parked();
 916                mutation_count -= 1;
 917            }
 918            75..=84 if mutation_count > 0 => {
 919                text_thread.update(cx, |text_thread, cx| {
 920                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 921                        let new_status = match rng.random_range(0..3) {
 922                            0 => MessageStatus::Done,
 923                            1 => MessageStatus::Pending,
 924                            _ => MessageStatus::Error(SharedString::from("Random error")),
 925                        };
 926                        log::info!(
 927                            "Context {}: update message {:?} status to {:?}",
 928                            context_index,
 929                            message.id,
 930                            new_status
 931                        );
 932                        text_thread.update_metadata(message.id, cx, |metadata| {
 933                            metadata.status = new_status;
 934                        });
 935                    }
 936                });
 937                mutation_count -= 1;
 938            }
 939            _ => {
 940                let replica_id = ReplicaId::new(context_index as u16);
 941                if network.lock().is_disconnected(replica_id) {
 942                    network.lock().reconnect_peer(replica_id, ReplicaId::new(0));
 943
 944                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
 945                        let host_context = &text_threads[0].read(cx);
 946                        let guest_context = text_thread.read(cx);
 947                        (
 948                            guest_context.serialize_ops(&host_context.version(cx), cx),
 949                            host_context.serialize_ops(&guest_context.version(cx), cx),
 950                        )
 951                    });
 952                    let ops_to_send = ops_to_send.await;
 953                    let ops_to_receive = ops_to_receive
 954                        .await
 955                        .into_iter()
 956                        .map(TextThreadOperation::from_proto)
 957                        .collect::<Result<Vec<_>>>()
 958                        .unwrap();
 959                    log::info!(
 960                        "Context {}: reconnecting. Sent {} operations, received {} operations",
 961                        context_index,
 962                        ops_to_send.len(),
 963                        ops_to_receive.len()
 964                    );
 965
 966                    network.lock().broadcast(replica_id, ops_to_send);
 967                    text_thread.update(cx, |text_thread, cx| {
 968                        text_thread.apply_ops(ops_to_receive, cx)
 969                    });
 970                } else if rng.random_bool(0.1) && replica_id != ReplicaId::new(0) {
 971                    log::info!("Context {}: disconnecting", context_index);
 972                    network.lock().disconnect_peer(replica_id);
 973                } else if network.lock().has_unreceived(replica_id) {
 974                    log::info!("Context {}: applying operations", context_index);
 975                    let ops = network.lock().receive(replica_id);
 976                    let ops = ops
 977                        .into_iter()
 978                        .map(TextThreadOperation::from_proto)
 979                        .collect::<Result<Vec<_>>>()
 980                        .unwrap();
 981                    text_thread.update(cx, |text_thread, cx| text_thread.apply_ops(ops, cx));
 982                }
 983            }
 984        }
 985    }
 986
 987    cx.read(|cx| {
 988        let first_context = text_threads[0].read(cx);
 989        for text_thread in &text_threads[1..] {
 990            let text_thread = text_thread.read(cx);
 991            assert!(text_thread.pending_ops.is_empty(), "pending ops: {:?}", text_thread.pending_ops);
 992            assert_eq!(
 993                text_thread.buffer().read(cx).text(),
 994                first_context.buffer().read(cx).text(),
 995                "Context {:?} text != Context 0 text",
 996                text_thread.buffer().read(cx).replica_id()
 997            );
 998            assert_eq!(
 999                text_thread.message_anchors,
1000                first_context.message_anchors,
1001                "Context {:?} messages != Context 0 messages",
1002                text_thread.buffer().read(cx).replica_id()
1003            );
1004            assert_eq!(
1005                text_thread.messages_metadata,
1006                first_context.messages_metadata,
1007                "Context {:?} message metadata != Context 0 message metadata",
1008                text_thread.buffer().read(cx).replica_id()
1009            );
1010            assert_eq!(
1011                text_thread.slash_command_output_sections,
1012                first_context.slash_command_output_sections,
1013                "Context {:?} slash command output sections != Context 0 slash command output sections",
1014                text_thread.buffer().read(cx).replica_id()
1015            );
1016        }
1017    });
1018}
1019
1020#[gpui::test]
1021fn test_mark_cache_anchors(cx: &mut App) {
1022    init_test(cx);
1023
1024    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1025    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1026    let text_thread = cx.new(|cx| {
1027        TextThread::local(
1028            registry,
1029            prompt_builder.clone(),
1030            Arc::new(SlashCommandWorkingSet::default()),
1031            cx,
1032        )
1033    });
1034    let buffer = text_thread.read(cx).buffer().clone();
1035
1036    // Create a test cache configuration
1037    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1038        max_cache_anchors: 3,
1039        should_speculate: true,
1040        min_total_token: 10,
1041    });
1042
1043    let message_1 = text_thread.read(cx).message_anchors[0].clone();
1044
1045    text_thread.update(cx, |text_thread, cx| {
1046        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1047    });
1048
1049    assert_eq!(
1050        messages_cache(&text_thread, cx)
1051            .iter()
1052            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1053            .count(),
1054        0,
1055        "Empty messages should not have any cache anchors."
1056    );
1057
1058    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1059    let message_2 = text_thread
1060        .update(cx, |text_thread, cx| {
1061            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1062        })
1063        .unwrap();
1064
1065    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1066    let message_3 = text_thread
1067        .update(cx, |text_thread, cx| {
1068            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1069        })
1070        .unwrap();
1071    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1072
1073    text_thread.update(cx, |text_thread, cx| {
1074        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1075    });
1076    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1077    assert_eq!(
1078        messages_cache(&text_thread, cx)
1079            .iter()
1080            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1081            .count(),
1082        0,
1083        "Messages should not be marked for cache before going over the token minimum."
1084    );
1085    text_thread.update(cx, |text_thread, _| {
1086        text_thread.token_count = Some(20);
1087    });
1088
1089    text_thread.update(cx, |text_thread, cx| {
1090        text_thread.mark_cache_anchors(cache_configuration, true, cx)
1091    });
1092    assert_eq!(
1093        messages_cache(&text_thread, cx)
1094            .iter()
1095            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1096            .collect::<Vec<bool>>(),
1097        vec![true, true, false],
1098        "Last message should not be an anchor on speculative request."
1099    );
1100
1101    text_thread
1102        .update(cx, |text_thread, cx| {
1103            text_thread.insert_message_after(
1104                message_3.id,
1105                Role::Assistant,
1106                MessageStatus::Pending,
1107                cx,
1108            )
1109        })
1110        .unwrap();
1111
1112    text_thread.update(cx, |text_thread, cx| {
1113        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1114    });
1115    assert_eq!(
1116        messages_cache(&text_thread, cx)
1117            .iter()
1118            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1119            .collect::<Vec<bool>>(),
1120        vec![false, true, true, false],
1121        "Most recent message should also be cached if not a speculative request."
1122    );
1123    text_thread.update(cx, |text_thread, cx| {
1124        text_thread.update_cache_status_for_completion(cx)
1125    });
1126    assert_eq!(
1127        messages_cache(&text_thread, cx)
1128            .iter()
1129            .map(|(_, cache)| cache
1130                .as_ref()
1131                .map_or(None, |cache| Some(cache.status.clone())))
1132            .collect::<Vec<Option<CacheStatus>>>(),
1133        vec![
1134            Some(CacheStatus::Cached),
1135            Some(CacheStatus::Cached),
1136            Some(CacheStatus::Cached),
1137            None
1138        ],
1139        "All user messages prior to anchor should be marked as cached."
1140    );
1141
1142    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1143    text_thread.update(cx, |text_thread, cx| {
1144        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1145    });
1146    assert_eq!(
1147        messages_cache(&text_thread, cx)
1148            .iter()
1149            .map(|(_, cache)| cache
1150                .as_ref()
1151                .map_or(None, |cache| Some(cache.status.clone())))
1152            .collect::<Vec<Option<CacheStatus>>>(),
1153        vec![
1154            Some(CacheStatus::Cached),
1155            Some(CacheStatus::Cached),
1156            Some(CacheStatus::Pending),
1157            None
1158        ],
1159        "Modifying a message should invalidate it's cache but leave previous messages."
1160    );
1161    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1162    text_thread.update(cx, |text_thread, cx| {
1163        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1164    });
1165    assert_eq!(
1166        messages_cache(&text_thread, cx)
1167            .iter()
1168            .map(|(_, cache)| cache
1169                .as_ref()
1170                .map_or(None, |cache| Some(cache.status.clone())))
1171            .collect::<Vec<Option<CacheStatus>>>(),
1172        vec![
1173            Some(CacheStatus::Pending),
1174            Some(CacheStatus::Pending),
1175            Some(CacheStatus::Pending),
1176            None
1177        ],
1178        "Modifying a message should invalidate all future messages."
1179    );
1180}
1181
1182#[gpui::test]
1183async fn test_summarization(cx: &mut TestAppContext) {
1184    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1185
1186    // Initial state should be pending
1187    text_thread.read_with(cx, |text_thread, _| {
1188        assert!(matches!(text_thread.summary(), TextThreadSummary::Pending));
1189        assert_eq!(
1190            text_thread.summary().or_default(),
1191            TextThreadSummary::DEFAULT
1192        );
1193    });
1194
1195    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1196        text_thread.message_anchors[0].clone()
1197    });
1198    text_thread.update(cx, |context, cx| {
1199        context
1200            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1201            .unwrap();
1202    });
1203
1204    // Send a message
1205    text_thread.update(cx, |text_thread, cx| {
1206        text_thread.assist(cx);
1207    });
1208
1209    simulate_successful_response(&fake_model, cx);
1210
1211    // Should start generating summary when there are >= 2 messages
1212    text_thread.read_with(cx, |text_thread, _| {
1213        assert!(!text_thread.summary().content().unwrap().done);
1214    });
1215
1216    cx.run_until_parked();
1217    fake_model.send_last_completion_stream_text_chunk("Brief");
1218    fake_model.send_last_completion_stream_text_chunk(" Introduction");
1219    fake_model.end_last_completion_stream();
1220    cx.run_until_parked();
1221
1222    // Summary should be set
1223    text_thread.read_with(cx, |text_thread, _| {
1224        assert_eq!(text_thread.summary().or_default(), "Brief Introduction");
1225    });
1226
1227    // We should be able to manually set a summary
1228    text_thread.update(cx, |text_thread, cx| {
1229        text_thread.set_custom_summary("Brief Intro".into(), cx);
1230    });
1231
1232    text_thread.read_with(cx, |text_thread, _| {
1233        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1234    });
1235}
1236
1237#[gpui::test]
1238async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1239    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1240
1241    test_summarize_error(&fake_model, &text_thread, cx);
1242
1243    // Now we should be able to set a summary
1244    text_thread.update(cx, |text_thread, cx| {
1245        text_thread.set_custom_summary("Brief Intro".into(), cx);
1246    });
1247
1248    text_thread.read_with(cx, |text_thread, _| {
1249        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1250    });
1251}
1252
1253#[gpui::test]
1254async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1255    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1256
1257    test_summarize_error(&fake_model, &text_thread, cx);
1258
1259    // Sending another message should not trigger another summarize request
1260    text_thread.update(cx, |text_thread, cx| {
1261        text_thread.assist(cx);
1262    });
1263
1264    simulate_successful_response(&fake_model, cx);
1265
1266    text_thread.read_with(cx, |text_thread, _| {
1267        // State is still Error, not Generating
1268        assert!(matches!(text_thread.summary(), TextThreadSummary::Error));
1269    });
1270
1271    // But the summarize request can be invoked manually
1272    text_thread.update(cx, |text_thread, cx| {
1273        text_thread.summarize(true, cx);
1274    });
1275
1276    text_thread.read_with(cx, |text_thread, _| {
1277        assert!(!text_thread.summary().content().unwrap().done);
1278    });
1279
1280    cx.run_until_parked();
1281    fake_model.send_last_completion_stream_text_chunk("A successful summary");
1282    fake_model.end_last_completion_stream();
1283    cx.run_until_parked();
1284
1285    text_thread.read_with(cx, |text_thread, _| {
1286        assert_eq!(text_thread.summary().or_default(), "A successful summary");
1287    });
1288}
1289
1290fn test_summarize_error(
1291    model: &Arc<FakeLanguageModel>,
1292    text_thread: &Entity<TextThread>,
1293    cx: &mut TestAppContext,
1294) {
1295    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1296        text_thread.message_anchors[0].clone()
1297    });
1298    text_thread.update(cx, |text_thread, cx| {
1299        text_thread
1300            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1301            .unwrap();
1302    });
1303
1304    // Send a message
1305    text_thread.update(cx, |text_thread, cx| {
1306        text_thread.assist(cx);
1307    });
1308
1309    simulate_successful_response(model, cx);
1310
1311    text_thread.read_with(cx, |text_thread, _| {
1312        assert!(!text_thread.summary().content().unwrap().done);
1313    });
1314
1315    // Simulate summary request ending
1316    cx.run_until_parked();
1317    model.end_last_completion_stream();
1318    cx.run_until_parked();
1319
1320    // State is set to Error and default message
1321    text_thread.read_with(cx, |text_thread, _| {
1322        assert_eq!(*text_thread.summary(), TextThreadSummary::Error);
1323        assert_eq!(
1324            text_thread.summary().or_default(),
1325            TextThreadSummary::DEFAULT
1326        );
1327    });
1328}
1329
1330fn setup_context_editor_with_fake_model(
1331    cx: &mut TestAppContext,
1332) -> (Entity<TextThread>, Arc<FakeLanguageModel>) {
1333    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1334
1335    let fake_provider = Arc::new(FakeLanguageModelProvider::default());
1336    let fake_model = Arc::new(fake_provider.test_model());
1337
1338    cx.update(|cx| {
1339        init_test(cx);
1340        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1341            let configured_model = ConfiguredModel {
1342                provider: fake_provider.clone(),
1343                model: fake_model.clone(),
1344            };
1345            registry.set_default_model(Some(configured_model.clone()), cx);
1346            registry.set_thread_summary_model(Some(configured_model), cx);
1347        })
1348    });
1349
1350    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1351    let context = cx.new(|cx| {
1352        TextThread::local(
1353            registry,
1354            prompt_builder.clone(),
1355            Arc::new(SlashCommandWorkingSet::default()),
1356            cx,
1357        )
1358    });
1359
1360    (context, fake_model)
1361}
1362
1363fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1364    cx.run_until_parked();
1365    fake_model.send_last_completion_stream_text_chunk("Assistant response");
1366    fake_model.end_last_completion_stream();
1367    cx.run_until_parked();
1368}
1369
1370fn messages(context: &Entity<TextThread>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1371    context
1372        .read(cx)
1373        .messages(cx)
1374        .map(|message| (message.id, message.role, message.offset_range))
1375        .collect()
1376}
1377
1378fn messages_cache(
1379    context: &Entity<TextThread>,
1380    cx: &App,
1381) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1382    context
1383        .read(cx)
1384        .messages(cx)
1385        .map(|message| (message.id, message.cache))
1386        .collect()
1387}
1388
/// Prepares `cx` for the tests in this file.
///
/// It does the following:
/// - builds a test `SettingsStore`;
/// - runs `prompt_store::init`;
/// - calls `LanguageModelRegistry::test`;
/// - registers the settings store as a global.
///
/// NOTE(review): the settings store is created first but only installed via
/// `set_global` after the other two initializers run; that order is kept as-is —
/// confirm it is intentional before reordering.
fn init_test(cx: &mut App) {
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    // Presumably installs a test registry as the global that
    // `LanguageModelRegistry::global` returns elsewhere in this file — verify.
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
}
1395
1396#[derive(Clone)]
1397struct FakeSlashCommand(String);
1398
1399impl SlashCommand for FakeSlashCommand {
1400    fn name(&self) -> String {
1401        self.0.clone()
1402    }
1403
1404    fn description(&self) -> String {
1405        format!("Fake slash command: {}", self.0)
1406    }
1407
1408    fn menu_text(&self) -> String {
1409        format!("Run fake command: {}", self.0)
1410    }
1411
1412    fn complete_argument(
1413        self: Arc<Self>,
1414        _arguments: &[String],
1415        _cancel: Arc<AtomicBool>,
1416        _workspace: Option<WeakEntity<Workspace>>,
1417        _window: &mut Window,
1418        _cx: &mut App,
1419    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1420        Task::ready(Ok(vec![]))
1421    }
1422
1423    fn requires_argument(&self) -> bool {
1424        false
1425    }
1426
1427    fn run(
1428        self: Arc<Self>,
1429        _arguments: &[String],
1430        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1431        _context_buffer: BufferSnapshot,
1432        _workspace: WeakEntity<Workspace>,
1433        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1434        _window: &mut Window,
1435        _cx: &mut App,
1436    ) -> Task<SlashCommandResult> {
1437        Task::ready(Ok(SlashCommandOutput {
1438            text: format!("Executed fake command: {}", self.0),
1439            sections: vec![],
1440            run_commands_in_text: false,
1441        }
1442        .into_event_stream()))
1443    }
1444}