assistant_text_thread_tests.rs

   1use crate::{
   2    CacheStatus, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, TextThread,
   3    TextThreadEvent, TextThreadId, TextThreadOperation, TextThreadSummary,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{
  20    ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
  21    fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
  22};
  23use parking_lot::Mutex;
  24use pretty_assertions::assert_eq;
  25use prompt_store::PromptBuilder;
  26use rand::prelude::*;
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{
  30    cell::RefCell,
  31    env,
  32    ops::Range,
  33    path::Path,
  34    rc::Rc,
  35    sync::{Arc, atomic::AtomicBool},
  36};
  37use text::{ReplicaId, ToOffset, network::Network};
  38use ui::{IconName, Window};
  39use unindent::Unindent;
  40use util::RandomCharIter;
  41use workspace::Workspace;
  42
  43#[gpui::test]
  44fn test_inserting_and_removing_messages(cx: &mut App) {
  45    init_test(cx);
  46
  47    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  48    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  49    let text_thread = cx.new(|cx| {
  50        TextThread::local(
  51            registry,
  52            None,
  53            None,
  54            prompt_builder.clone(),
  55            Arc::new(SlashCommandWorkingSet::default()),
  56            cx,
  57        )
  58    });
  59    let buffer = text_thread.read(cx).buffer().clone();
  60
  61    let message_1 = text_thread.read(cx).message_anchors[0].clone();
  62    assert_eq!(
  63        messages(&text_thread, cx),
  64        vec![(message_1.id, Role::User, 0..0)]
  65    );
  66
  67    let message_2 = text_thread.update(cx, |context, cx| {
  68        context
  69            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  70            .unwrap()
  71    });
  72    assert_eq!(
  73        messages(&text_thread, cx),
  74        vec![
  75            (message_1.id, Role::User, 0..1),
  76            (message_2.id, Role::Assistant, 1..1)
  77        ]
  78    );
  79
  80    buffer.update(cx, |buffer, cx| {
  81        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  82    });
  83    assert_eq!(
  84        messages(&text_thread, cx),
  85        vec![
  86            (message_1.id, Role::User, 0..2),
  87            (message_2.id, Role::Assistant, 2..3)
  88        ]
  89    );
  90
  91    let message_3 = text_thread.update(cx, |context, cx| {
  92        context
  93            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  94            .unwrap()
  95    });
  96    assert_eq!(
  97        messages(&text_thread, cx),
  98        vec![
  99            (message_1.id, Role::User, 0..2),
 100            (message_2.id, Role::Assistant, 2..4),
 101            (message_3.id, Role::User, 4..4)
 102        ]
 103    );
 104
 105    let message_4 = text_thread.update(cx, |context, cx| {
 106        context
 107            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 108            .unwrap()
 109    });
 110    assert_eq!(
 111        messages(&text_thread, cx),
 112        vec![
 113            (message_1.id, Role::User, 0..2),
 114            (message_2.id, Role::Assistant, 2..4),
 115            (message_4.id, Role::User, 4..5),
 116            (message_3.id, Role::User, 5..5),
 117        ]
 118    );
 119
 120    buffer.update(cx, |buffer, cx| {
 121        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 122    });
 123    assert_eq!(
 124        messages(&text_thread, cx),
 125        vec![
 126            (message_1.id, Role::User, 0..2),
 127            (message_2.id, Role::Assistant, 2..4),
 128            (message_4.id, Role::User, 4..6),
 129            (message_3.id, Role::User, 6..7),
 130        ]
 131    );
 132
 133    // Deleting across message boundaries merges the messages.
 134    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 135    assert_eq!(
 136        messages(&text_thread, cx),
 137        vec![
 138            (message_1.id, Role::User, 0..3),
 139            (message_3.id, Role::User, 3..4),
 140        ]
 141    );
 142
 143    // Undoing the deletion should also undo the merge.
 144    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 145    assert_eq!(
 146        messages(&text_thread, cx),
 147        vec![
 148            (message_1.id, Role::User, 0..2),
 149            (message_2.id, Role::Assistant, 2..4),
 150            (message_4.id, Role::User, 4..6),
 151            (message_3.id, Role::User, 6..7),
 152        ]
 153    );
 154
 155    // Redoing the deletion should also redo the merge.
 156    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 157    assert_eq!(
 158        messages(&text_thread, cx),
 159        vec![
 160            (message_1.id, Role::User, 0..3),
 161            (message_3.id, Role::User, 3..4),
 162        ]
 163    );
 164
 165    // Ensure we can still insert after a merged message.
 166    let message_5 = text_thread.update(cx, |context, cx| {
 167        context
 168            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 169            .unwrap()
 170    });
 171    assert_eq!(
 172        messages(&text_thread, cx),
 173        vec![
 174            (message_1.id, Role::User, 0..3),
 175            (message_5.id, Role::System, 3..4),
 176            (message_3.id, Role::User, 4..5)
 177        ]
 178    );
 179}
 180
 181#[gpui::test]
 182fn test_message_splitting(cx: &mut App) {
 183    init_test(cx);
 184
 185    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 186
 187    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 188    let text_thread = cx.new(|cx| {
 189        TextThread::local(
 190            registry.clone(),
 191            None,
 192            None,
 193            prompt_builder.clone(),
 194            Arc::new(SlashCommandWorkingSet::default()),
 195            cx,
 196        )
 197    });
 198    let buffer = text_thread.read(cx).buffer().clone();
 199
 200    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 201    assert_eq!(
 202        messages(&text_thread, cx),
 203        vec![(message_1.id, Role::User, 0..0)]
 204    );
 205
 206    buffer.update(cx, |buffer, cx| {
 207        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 208    });
 209
 210    let (_, message_2) =
 211        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 212    let message_2 = message_2.unwrap();
 213
 214    // We recycle newlines in the middle of a split message
 215    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 216    assert_eq!(
 217        messages(&text_thread, cx),
 218        vec![
 219            (message_1.id, Role::User, 0..4),
 220            (message_2.id, Role::User, 4..16),
 221        ]
 222    );
 223
 224    let (_, message_3) =
 225        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 226    let message_3 = message_3.unwrap();
 227
 228    // We don't recycle newlines at the end of a split message
 229    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 230    assert_eq!(
 231        messages(&text_thread, cx),
 232        vec![
 233            (message_1.id, Role::User, 0..4),
 234            (message_3.id, Role::User, 4..5),
 235            (message_2.id, Role::User, 5..17),
 236        ]
 237    );
 238
 239    let (_, message_4) =
 240        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 241    let message_4 = message_4.unwrap();
 242    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 243    assert_eq!(
 244        messages(&text_thread, cx),
 245        vec![
 246            (message_1.id, Role::User, 0..4),
 247            (message_3.id, Role::User, 4..5),
 248            (message_2.id, Role::User, 5..9),
 249            (message_4.id, Role::User, 9..17),
 250        ]
 251    );
 252
 253    let (_, message_5) =
 254        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 255    let message_5 = message_5.unwrap();
 256    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 257    assert_eq!(
 258        messages(&text_thread, cx),
 259        vec![
 260            (message_1.id, Role::User, 0..4),
 261            (message_3.id, Role::User, 4..5),
 262            (message_2.id, Role::User, 5..9),
 263            (message_4.id, Role::User, 9..10),
 264            (message_5.id, Role::User, 10..18),
 265        ]
 266    );
 267
 268    let (message_6, message_7) =
 269        text_thread.update(cx, |text_thread, cx| text_thread.split_message(14..16, cx));
 270    let message_6 = message_6.unwrap();
 271    let message_7 = message_7.unwrap();
 272    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 273    assert_eq!(
 274        messages(&text_thread, cx),
 275        vec![
 276            (message_1.id, Role::User, 0..4),
 277            (message_3.id, Role::User, 4..5),
 278            (message_2.id, Role::User, 5..9),
 279            (message_4.id, Role::User, 9..10),
 280            (message_5.id, Role::User, 10..14),
 281            (message_6.id, Role::User, 14..17),
 282            (message_7.id, Role::User, 17..19),
 283        ]
 284    );
 285}
 286
 287#[gpui::test]
 288fn test_messages_for_offsets(cx: &mut App) {
 289    init_test(cx);
 290
 291    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 292    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 293    let text_thread = cx.new(|cx| {
 294        TextThread::local(
 295            registry,
 296            None,
 297            None,
 298            prompt_builder.clone(),
 299            Arc::new(SlashCommandWorkingSet::default()),
 300            cx,
 301        )
 302    });
 303    let buffer = text_thread.read(cx).buffer().clone();
 304
 305    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 306    assert_eq!(
 307        messages(&text_thread, cx),
 308        vec![(message_1.id, Role::User, 0..0)]
 309    );
 310
 311    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 312    let message_2 = text_thread
 313        .update(cx, |text_thread, cx| {
 314            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 315        })
 316        .unwrap();
 317    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 318
 319    let message_3 = text_thread
 320        .update(cx, |text_thread, cx| {
 321            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 322        })
 323        .unwrap();
 324    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 325
 326    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 327    assert_eq!(
 328        messages(&text_thread, cx),
 329        vec![
 330            (message_1.id, Role::User, 0..4),
 331            (message_2.id, Role::User, 4..8),
 332            (message_3.id, Role::User, 8..11)
 333        ]
 334    );
 335
 336    assert_eq!(
 337        message_ids_for_offsets(&text_thread, &[0, 4, 9], cx),
 338        [message_1.id, message_2.id, message_3.id]
 339    );
 340    assert_eq!(
 341        message_ids_for_offsets(&text_thread, &[0, 1, 11], cx),
 342        [message_1.id, message_3.id]
 343    );
 344
 345    let message_4 = text_thread
 346        .update(cx, |text_thread, cx| {
 347            text_thread.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 348        })
 349        .unwrap();
 350    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 351    assert_eq!(
 352        messages(&text_thread, cx),
 353        vec![
 354            (message_1.id, Role::User, 0..4),
 355            (message_2.id, Role::User, 4..8),
 356            (message_3.id, Role::User, 8..12),
 357            (message_4.id, Role::User, 12..12)
 358        ]
 359    );
 360    assert_eq!(
 361        message_ids_for_offsets(&text_thread, &[0, 4, 8, 12], cx),
 362        [message_1.id, message_2.id, message_3.id, message_4.id]
 363    );
 364
 365    fn message_ids_for_offsets(
 366        context: &Entity<TextThread>,
 367        offsets: &[usize],
 368        cx: &App,
 369    ) -> Vec<MessageId> {
 370        context
 371            .read(cx)
 372            .messages_for_offsets(offsets.iter().copied(), cx)
 373            .into_iter()
 374            .map(|message| message.id)
 375            .collect()
 376    }
 377}
 378
 379#[gpui::test]
 380async fn test_slash_commands(cx: &mut TestAppContext) {
 381    cx.update(init_test);
 382
 383    let fs = FakeFs::new(cx.background_executor.clone());
 384
 385    fs.insert_tree(
 386        "/test",
 387        json!({
 388            "src": {
 389                "lib.rs": "fn one() -> usize { 1 }",
 390                "main.rs": "
 391                    use crate::one;
 392                    fn main() { one(); }
 393                ".unindent(),
 394            }
 395        }),
 396    )
 397    .await;
 398
 399    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 400    slash_command_registry.register_command(FileSlashCommand, false);
 401
 402    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 403    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 404    let text_thread = cx.new(|cx| {
 405        TextThread::local(
 406            registry.clone(),
 407            None,
 408            None,
 409            prompt_builder.clone(),
 410            Arc::new(SlashCommandWorkingSet::default()),
 411            cx,
 412        )
 413    });
 414
 415    #[derive(Default)]
 416    struct ContextRanges {
 417        parsed_commands: HashSet<Range<language::Anchor>>,
 418        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 419        output_sections: HashSet<Range<language::Anchor>>,
 420    }
 421
 422    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 423    text_thread.update(cx, |_, cx| {
 424        cx.subscribe(&text_thread, {
 425            let context_ranges = context_ranges.clone();
 426            move |text_thread, _, event, _| {
 427                let mut context_ranges = context_ranges.borrow_mut();
 428                match event {
 429                    TextThreadEvent::InvokedSlashCommandChanged { command_id } => {
 430                        let command = text_thread.invoked_slash_command(command_id).unwrap();
 431                        context_ranges
 432                            .command_outputs
 433                            .insert(*command_id, command.range.clone());
 434                    }
 435                    TextThreadEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 436                        for range in removed {
 437                            context_ranges.parsed_commands.remove(range);
 438                        }
 439                        for command in updated {
 440                            context_ranges
 441                                .parsed_commands
 442                                .insert(command.source_range.clone());
 443                        }
 444                    }
 445                    TextThreadEvent::SlashCommandOutputSectionAdded { section } => {
 446                        context_ranges.output_sections.insert(section.range.clone());
 447                    }
 448                    _ => {}
 449                }
 450            }
 451        })
 452        .detach();
 453    });
 454
 455    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 456
 457    // Insert a slash command
 458    buffer.update(cx, |buffer, cx| {
 459        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 460    });
 461    assert_text_and_context_ranges(
 462        &buffer,
 463        &context_ranges,
 464        &"
 465        «/file src/lib.rs»"
 466            .unindent(),
 467        cx,
 468    );
 469
 470    // Edit the argument of the slash command.
 471    buffer.update(cx, |buffer, cx| {
 472        let edit_offset = buffer.text().find("lib.rs").unwrap();
 473        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 474    });
 475    assert_text_and_context_ranges(
 476        &buffer,
 477        &context_ranges,
 478        &"
 479        «/file src/main.rs»"
 480            .unindent(),
 481        cx,
 482    );
 483
 484    // Edit the name of the slash command, using one that doesn't exist.
 485    buffer.update(cx, |buffer, cx| {
 486        let edit_offset = buffer.text().find("/file").unwrap();
 487        buffer.edit(
 488            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 489            None,
 490            cx,
 491        );
 492    });
 493    assert_text_and_context_ranges(
 494        &buffer,
 495        &context_ranges,
 496        &"
 497        /unknown src/main.rs"
 498            .unindent(),
 499        cx,
 500    );
 501
 502    // Undoing the insertion of an non-existent slash command resorts the previous one.
 503    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 504    assert_text_and_context_ranges(
 505        &buffer,
 506        &context_ranges,
 507        &"
 508        «/file src/main.rs»"
 509            .unindent(),
 510        cx,
 511    );
 512
 513    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 514    text_thread.update(cx, |text_thread, cx| {
 515        let command_source_range = text_thread.parsed_slash_commands[0].source_range.clone();
 516        text_thread.insert_command_output(
 517            command_source_range,
 518            "file",
 519            Task::ready(Ok(command_output_rx.boxed())),
 520            true,
 521            cx,
 522        );
 523    });
 524    assert_text_and_context_ranges(
 525        &buffer,
 526        &context_ranges,
 527        &"
 528        ⟦«/file src/main.rs»
 529        …⟧
 530        "
 531        .unindent(),
 532        cx,
 533    );
 534
 535    command_output_tx
 536        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 537            icon: IconName::Ai,
 538            label: "src/main.rs".into(),
 539            metadata: None,
 540        }))
 541        .unwrap();
 542    command_output_tx
 543        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 544        .unwrap();
 545    cx.run_until_parked();
 546    assert_text_and_context_ranges(
 547        &buffer,
 548        &context_ranges,
 549        &"
 550        ⟦«/file src/main.rs»
 551        src/main.rs…⟧
 552        "
 553        .unindent(),
 554        cx,
 555    );
 556
 557    command_output_tx
 558        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 559        .unwrap();
 560    cx.run_until_parked();
 561    assert_text_and_context_ranges(
 562        &buffer,
 563        &context_ranges,
 564        &"
 565        ⟦«/file src/main.rs»
 566        src/main.rs
 567        fn main() {}…⟧
 568        "
 569        .unindent(),
 570        cx,
 571    );
 572
 573    command_output_tx
 574        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 575        .unwrap();
 576    cx.run_until_parked();
 577    assert_text_and_context_ranges(
 578        &buffer,
 579        &context_ranges,
 580        &"
 581        ⟦«/file src/main.rs»
 582        ⟪src/main.rs
 583        fn main() {}⟫…⟧
 584        "
 585        .unindent(),
 586        cx,
 587    );
 588
 589    drop(command_output_tx);
 590    cx.run_until_parked();
 591    assert_text_and_context_ranges(
 592        &buffer,
 593        &context_ranges,
 594        &"
 595        ⟦⟪src/main.rs
 596        fn main() {}⟫⟧
 597        "
 598        .unindent(),
 599        cx,
 600    );
 601
 602    #[track_caller]
 603    fn assert_text_and_context_ranges(
 604        buffer: &Entity<Buffer>,
 605        ranges: &RefCell<ContextRanges>,
 606        expected_marked_text: &str,
 607        cx: &mut TestAppContext,
 608    ) {
 609        let mut actual_marked_text = String::new();
 610        buffer.update(cx, |buffer, _| {
 611            struct Endpoint {
 612                offset: usize,
 613                marker: char,
 614            }
 615
 616            let ranges = ranges.borrow();
 617            let mut endpoints = Vec::new();
 618            for range in ranges.command_outputs.values() {
 619                endpoints.push(Endpoint {
 620                    offset: range.start.to_offset(buffer),
 621                    marker: '',
 622                });
 623            }
 624            for range in ranges.parsed_commands.iter() {
 625                endpoints.push(Endpoint {
 626                    offset: range.start.to_offset(buffer),
 627                    marker: '«',
 628                });
 629            }
 630            for range in ranges.output_sections.iter() {
 631                endpoints.push(Endpoint {
 632                    offset: range.start.to_offset(buffer),
 633                    marker: '',
 634                });
 635            }
 636
 637            for range in ranges.output_sections.iter() {
 638                endpoints.push(Endpoint {
 639                    offset: range.end.to_offset(buffer),
 640                    marker: '',
 641                });
 642            }
 643            for range in ranges.parsed_commands.iter() {
 644                endpoints.push(Endpoint {
 645                    offset: range.end.to_offset(buffer),
 646                    marker: '»',
 647                });
 648            }
 649            for range in ranges.command_outputs.values() {
 650                endpoints.push(Endpoint {
 651                    offset: range.end.to_offset(buffer),
 652                    marker: '',
 653                });
 654            }
 655
 656            endpoints.sort_by_key(|endpoint| endpoint.offset);
 657            let mut offset = 0;
 658            for endpoint in endpoints {
 659                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 660                actual_marked_text.push(endpoint.marker);
 661                offset = endpoint.offset;
 662            }
 663            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 664        });
 665
 666        assert_eq!(actual_marked_text, expected_marked_text);
 667    }
 668}
 669
 670#[gpui::test]
 671async fn test_serialization(cx: &mut TestAppContext) {
 672    cx.update(init_test);
 673
 674    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 675    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 676    let text_thread = cx.new(|cx| {
 677        TextThread::local(
 678            registry.clone(),
 679            None,
 680            None,
 681            prompt_builder.clone(),
 682            Arc::new(SlashCommandWorkingSet::default()),
 683            cx,
 684        )
 685    });
 686    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 687    let message_0 = text_thread.read_with(cx, |text_thread, _| text_thread.message_anchors[0].id);
 688    let message_1 = text_thread.update(cx, |text_thread, cx| {
 689        text_thread
 690            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 691            .unwrap()
 692    });
 693    let message_2 = text_thread.update(cx, |text_thread, cx| {
 694        text_thread
 695            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 696            .unwrap()
 697    });
 698    buffer.update(cx, |buffer, cx| {
 699        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 700        buffer.finalize_last_transaction();
 701    });
 702    let _message_3 = text_thread.update(cx, |text_thread, cx| {
 703        text_thread
 704            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 705            .unwrap()
 706    });
 707    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 708    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 709    assert_eq!(
 710        cx.read(|cx| messages(&text_thread, cx)),
 711        [
 712            (message_0, Role::User, 0..2),
 713            (message_1.id, Role::Assistant, 2..6),
 714            (message_2.id, Role::System, 6..6),
 715        ]
 716    );
 717
 718    let serialized_context = text_thread.read_with(cx, |text_thread, cx| text_thread.serialize(cx));
 719    let deserialized_context = cx.new(|cx| {
 720        TextThread::deserialize(
 721            serialized_context,
 722            Path::new("").into(),
 723            registry.clone(),
 724            prompt_builder.clone(),
 725            Arc::new(SlashCommandWorkingSet::default()),
 726            None,
 727            None,
 728            cx,
 729        )
 730    });
 731    let deserialized_buffer =
 732        deserialized_context.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 733    assert_eq!(
 734        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 735        "a\nb\nc\n"
 736    );
 737    assert_eq!(
 738        cx.read(|cx| messages(&deserialized_context, cx)),
 739        [
 740            (message_0, Role::User, 0..2),
 741            (message_1.id, Role::Assistant, 2..6),
 742            (message_2.id, Role::System, 6..6),
 743        ]
 744    );
 745}
 746
 747#[gpui::test(iterations = 25)]
 748async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
 749    cx.update(init_test);
 750
 751    let min_peers = env::var("MIN_PEERS")
 752        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
 753        .unwrap_or(2);
 754    let max_peers = env::var("MAX_PEERS")
 755        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 756        .unwrap_or(5);
 757    let operations = env::var("OPERATIONS")
 758        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 759        .unwrap_or(50);
 760
 761    let slash_commands = cx.update(SlashCommandRegistry::default_global);
 762    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
 763    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
 764    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
 765
 766    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
 767    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
 768    let mut text_threads = Vec::new();
 769
 770    let num_peers = rng.random_range(min_peers..=max_peers);
 771    let context_id = TextThreadId::new();
 772    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 773    for i in 0..num_peers {
 774        let context = cx.new(|cx| {
 775            TextThread::new(
 776                context_id.clone(),
 777                ReplicaId::new(i as u16),
 778                language::Capability::ReadWrite,
 779                registry.clone(),
 780                prompt_builder.clone(),
 781                Arc::new(SlashCommandWorkingSet::default()),
 782                None,
 783                None,
 784                cx,
 785            )
 786        });
 787
 788        cx.update(|cx| {
 789            cx.subscribe(&context, {
 790                let network = network.clone();
 791                move |_, event, _| {
 792                    if let TextThreadEvent::Operation(op) = event {
 793                        network
 794                            .lock()
 795                            .broadcast(ReplicaId::new(i as u16), vec![op.to_proto()]);
 796                    }
 797                }
 798            })
 799            .detach();
 800        });
 801
 802        text_threads.push(context);
 803        network.lock().add_peer(ReplicaId::new(i as u16));
 804    }
 805
 806    let mut mutation_count = operations;
 807
 808    while mutation_count > 0
 809        || !network.lock().is_idle()
 810        || network.lock().contains_disconnected_peers()
 811    {
 812        let context_index = rng.random_range(0..text_threads.len());
 813        let text_thread = &text_threads[context_index];
 814
 815        match rng.random_range(0..100) {
 816            0..=29 if mutation_count > 0 => {
 817                log::info!("Context {}: edit buffer", context_index);
 818                text_thread.update(cx, |text_thread, cx| {
 819                    text_thread
 820                        .buffer()
 821                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
 822                });
 823                mutation_count -= 1;
 824            }
 825            30..=44 if mutation_count > 0 => {
 826                text_thread.update(cx, |text_thread, cx| {
 827                    let range = text_thread.buffer().read(cx).random_byte_range(0, &mut rng);
 828                    log::info!("Context {}: split message at {:?}", context_index, range);
 829                    text_thread.split_message(range, cx);
 830                });
 831                mutation_count -= 1;
 832            }
 833            45..=59 if mutation_count > 0 => {
 834                text_thread.update(cx, |text_thread, cx| {
 835                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 836                        let role = *[Role::User, Role::Assistant, Role::System]
 837                            .choose(&mut rng)
 838                            .unwrap();
 839                        log::info!(
 840                            "Context {}: insert message after {:?} with {:?}",
 841                            context_index,
 842                            message.id,
 843                            role
 844                        );
 845                        text_thread.insert_message_after(message.id, role, MessageStatus::Done, cx);
 846                    }
 847                });
 848                mutation_count -= 1;
 849            }
 850            60..=74 if mutation_count > 0 => {
 851                text_thread.update(cx, |text_thread, cx| {
 852                    let command_text = "/".to_string()
 853                        + slash_commands
 854                            .command_names()
 855                            .choose(&mut rng)
 856                            .unwrap()
 857                            .clone()
 858                            .as_ref();
 859
 860                    let command_range = text_thread.buffer().update(cx, |buffer, cx| {
 861                        let offset = buffer.random_byte_range(0, &mut rng).start;
 862                        buffer.edit(
 863                            [(offset..offset, format!("\n{}\n", command_text))],
 864                            None,
 865                            cx,
 866                        );
 867                        offset + 1..offset + 1 + command_text.len()
 868                    });
 869
 870                    let output_text = RandomCharIter::new(&mut rng)
 871                        .filter(|c| *c != '\r')
 872                        .take(10)
 873                        .collect::<String>();
 874
 875                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
 876                        role: Role::User,
 877                        merge_same_roles: true,
 878                    })];
 879
 880                    let num_sections = rng.random_range(0..=3);
 881                    let mut section_start = 0;
 882                    for _ in 0..num_sections {
 883                        let section_end = output_text.floor_char_boundary(
 884                            rng.random_range(section_start..=output_text.len()),
 885                        );
 886                        events.push(Ok(SlashCommandEvent::StartSection {
 887                            icon: IconName::Ai,
 888                            label: "section".into(),
 889                            metadata: None,
 890                        }));
 891                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 892                            text: output_text[section_start..section_end].to_string(),
 893                            run_commands_in_text: false,
 894                        })));
 895                        events.push(Ok(SlashCommandEvent::EndSection));
 896                        section_start = section_end;
 897                    }
 898
 899                    if section_start < output_text.len() {
 900                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 901                            text: output_text[section_start..].to_string(),
 902                            run_commands_in_text: false,
 903                        })));
 904                    }
 905
 906                    log::info!(
 907                        "Context {}: insert slash command output at {:?} with {:?} events",
 908                        context_index,
 909                        command_range,
 910                        events.len()
 911                    );
 912
 913                    let command_range = text_thread
 914                        .buffer()
 915                        .read(cx)
 916                        .anchor_after(command_range.start)
 917                        ..text_thread
 918                            .buffer()
 919                            .read(cx)
 920                            .anchor_after(command_range.end);
 921                    text_thread.insert_command_output(
 922                        command_range,
 923                        "/command",
 924                        Task::ready(Ok(stream::iter(events).boxed())),
 925                        true,
 926                        cx,
 927                    );
 928                });
 929                cx.run_until_parked();
 930                mutation_count -= 1;
 931            }
 932            75..=84 if mutation_count > 0 => {
 933                text_thread.update(cx, |text_thread, cx| {
 934                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 935                        let new_status = match rng.random_range(0..3) {
 936                            0 => MessageStatus::Done,
 937                            1 => MessageStatus::Pending,
 938                            _ => MessageStatus::Error(SharedString::from("Random error")),
 939                        };
 940                        log::info!(
 941                            "Context {}: update message {:?} status to {:?}",
 942                            context_index,
 943                            message.id,
 944                            new_status
 945                        );
 946                        text_thread.update_metadata(message.id, cx, |metadata| {
 947                            metadata.status = new_status;
 948                        });
 949                    }
 950                });
 951                mutation_count -= 1;
 952            }
 953            _ => {
 954                let replica_id = ReplicaId::new(context_index as u16);
 955                if network.lock().is_disconnected(replica_id) {
 956                    network.lock().reconnect_peer(replica_id, ReplicaId::new(0));
 957
 958                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
 959                        let host_context = &text_threads[0].read(cx);
 960                        let guest_context = text_thread.read(cx);
 961                        (
 962                            guest_context.serialize_ops(&host_context.version(cx), cx),
 963                            host_context.serialize_ops(&guest_context.version(cx), cx),
 964                        )
 965                    });
 966                    let ops_to_send = ops_to_send.await;
 967                    let ops_to_receive = ops_to_receive
 968                        .await
 969                        .into_iter()
 970                        .map(TextThreadOperation::from_proto)
 971                        .collect::<Result<Vec<_>>>()
 972                        .unwrap();
 973                    log::info!(
 974                        "Context {}: reconnecting. Sent {} operations, received {} operations",
 975                        context_index,
 976                        ops_to_send.len(),
 977                        ops_to_receive.len()
 978                    );
 979
 980                    network.lock().broadcast(replica_id, ops_to_send);
 981                    text_thread.update(cx, |text_thread, cx| {
 982                        text_thread.apply_ops(ops_to_receive, cx)
 983                    });
 984                } else if rng.random_bool(0.1) && replica_id != ReplicaId::new(0) {
 985                    log::info!("Context {}: disconnecting", context_index);
 986                    network.lock().disconnect_peer(replica_id);
 987                } else if network.lock().has_unreceived(replica_id) {
 988                    log::info!("Context {}: applying operations", context_index);
 989                    let ops = network.lock().receive(replica_id);
 990                    let ops = ops
 991                        .into_iter()
 992                        .map(TextThreadOperation::from_proto)
 993                        .collect::<Result<Vec<_>>>()
 994                        .unwrap();
 995                    text_thread.update(cx, |text_thread, cx| text_thread.apply_ops(ops, cx));
 996                }
 997            }
 998        }
 999    }
1000
1001    cx.read(|cx| {
1002        let first_context = text_threads[0].read(cx);
1003        for text_thread in &text_threads[1..] {
1004            let text_thread = text_thread.read(cx);
1005            assert!(text_thread.pending_ops.is_empty(), "pending ops: {:?}", text_thread.pending_ops);
1006            assert_eq!(
1007                text_thread.buffer().read(cx).text(),
1008                first_context.buffer().read(cx).text(),
1009                "Context {:?} text != Context 0 text",
1010                text_thread.buffer().read(cx).replica_id()
1011            );
1012            assert_eq!(
1013                text_thread.message_anchors,
1014                first_context.message_anchors,
1015                "Context {:?} messages != Context 0 messages",
1016                text_thread.buffer().read(cx).replica_id()
1017            );
1018            assert_eq!(
1019                text_thread.messages_metadata,
1020                first_context.messages_metadata,
1021                "Context {:?} message metadata != Context 0 message metadata",
1022                text_thread.buffer().read(cx).replica_id()
1023            );
1024            assert_eq!(
1025                text_thread.slash_command_output_sections,
1026                first_context.slash_command_output_sections,
1027                "Context {:?} slash command output sections != Context 0 slash command output sections",
1028                text_thread.buffer().read(cx).replica_id()
1029            );
1030        }
1031    });
1032}
1033
1034#[gpui::test]
1035fn test_mark_cache_anchors(cx: &mut App) {
1036    init_test(cx);
1037
1038    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1039    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1040    let text_thread = cx.new(|cx| {
1041        TextThread::local(
1042            registry,
1043            None,
1044            None,
1045            prompt_builder.clone(),
1046            Arc::new(SlashCommandWorkingSet::default()),
1047            cx,
1048        )
1049    });
1050    let buffer = text_thread.read(cx).buffer().clone();
1051
1052    // Create a test cache configuration
1053    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1054        max_cache_anchors: 3,
1055        should_speculate: true,
1056        min_total_token: 10,
1057    });
1058
1059    let message_1 = text_thread.read(cx).message_anchors[0].clone();
1060
1061    text_thread.update(cx, |text_thread, cx| {
1062        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1063    });
1064
1065    assert_eq!(
1066        messages_cache(&text_thread, cx)
1067            .iter()
1068            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1069            .count(),
1070        0,
1071        "Empty messages should not have any cache anchors."
1072    );
1073
1074    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1075    let message_2 = text_thread
1076        .update(cx, |text_thread, cx| {
1077            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1078        })
1079        .unwrap();
1080
1081    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1082    let message_3 = text_thread
1083        .update(cx, |text_thread, cx| {
1084            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1085        })
1086        .unwrap();
1087    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1088
1089    text_thread.update(cx, |text_thread, cx| {
1090        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1091    });
1092    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1093    assert_eq!(
1094        messages_cache(&text_thread, cx)
1095            .iter()
1096            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1097            .count(),
1098        0,
1099        "Messages should not be marked for cache before going over the token minimum."
1100    );
1101    text_thread.update(cx, |text_thread, _| {
1102        text_thread.token_count = Some(20);
1103    });
1104
1105    text_thread.update(cx, |text_thread, cx| {
1106        text_thread.mark_cache_anchors(cache_configuration, true, cx)
1107    });
1108    assert_eq!(
1109        messages_cache(&text_thread, cx)
1110            .iter()
1111            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1112            .collect::<Vec<bool>>(),
1113        vec![true, true, false],
1114        "Last message should not be an anchor on speculative request."
1115    );
1116
1117    text_thread
1118        .update(cx, |text_thread, cx| {
1119            text_thread.insert_message_after(
1120                message_3.id,
1121                Role::Assistant,
1122                MessageStatus::Pending,
1123                cx,
1124            )
1125        })
1126        .unwrap();
1127
1128    text_thread.update(cx, |text_thread, cx| {
1129        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1130    });
1131    assert_eq!(
1132        messages_cache(&text_thread, cx)
1133            .iter()
1134            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1135            .collect::<Vec<bool>>(),
1136        vec![false, true, true, false],
1137        "Most recent message should also be cached if not a speculative request."
1138    );
1139    text_thread.update(cx, |text_thread, cx| {
1140        text_thread.update_cache_status_for_completion(cx)
1141    });
1142    assert_eq!(
1143        messages_cache(&text_thread, cx)
1144            .iter()
1145            .map(|(_, cache)| cache
1146                .as_ref()
1147                .map_or(None, |cache| Some(cache.status.clone())))
1148            .collect::<Vec<Option<CacheStatus>>>(),
1149        vec![
1150            Some(CacheStatus::Cached),
1151            Some(CacheStatus::Cached),
1152            Some(CacheStatus::Cached),
1153            None
1154        ],
1155        "All user messages prior to anchor should be marked as cached."
1156    );
1157
1158    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1159    text_thread.update(cx, |text_thread, cx| {
1160        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1161    });
1162    assert_eq!(
1163        messages_cache(&text_thread, cx)
1164            .iter()
1165            .map(|(_, cache)| cache
1166                .as_ref()
1167                .map_or(None, |cache| Some(cache.status.clone())))
1168            .collect::<Vec<Option<CacheStatus>>>(),
1169        vec![
1170            Some(CacheStatus::Cached),
1171            Some(CacheStatus::Cached),
1172            Some(CacheStatus::Pending),
1173            None
1174        ],
1175        "Modifying a message should invalidate it's cache but leave previous messages."
1176    );
1177    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1178    text_thread.update(cx, |text_thread, cx| {
1179        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1180    });
1181    assert_eq!(
1182        messages_cache(&text_thread, cx)
1183            .iter()
1184            .map(|(_, cache)| cache
1185                .as_ref()
1186                .map_or(None, |cache| Some(cache.status.clone())))
1187            .collect::<Vec<Option<CacheStatus>>>(),
1188        vec![
1189            Some(CacheStatus::Pending),
1190            Some(CacheStatus::Pending),
1191            Some(CacheStatus::Pending),
1192            None
1193        ],
1194        "Modifying a message should invalidate all future messages."
1195    );
1196}
1197
1198#[gpui::test]
1199async fn test_summarization(cx: &mut TestAppContext) {
1200    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1201
1202    // Initial state should be pending
1203    text_thread.read_with(cx, |text_thread, _| {
1204        assert!(matches!(text_thread.summary(), TextThreadSummary::Pending));
1205        assert_eq!(
1206            text_thread.summary().or_default(),
1207            TextThreadSummary::DEFAULT
1208        );
1209    });
1210
1211    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1212        text_thread.message_anchors[0].clone()
1213    });
1214    text_thread.update(cx, |context, cx| {
1215        context
1216            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1217            .unwrap();
1218    });
1219
1220    // Send a message
1221    text_thread.update(cx, |text_thread, cx| {
1222        text_thread.assist(cx);
1223    });
1224
1225    simulate_successful_response(&fake_model, cx);
1226
1227    // Should start generating summary when there are >= 2 messages
1228    text_thread.read_with(cx, |text_thread, _| {
1229        assert!(!text_thread.summary().content().unwrap().done);
1230    });
1231
1232    cx.run_until_parked();
1233    fake_model.send_last_completion_stream_text_chunk("Brief");
1234    fake_model.send_last_completion_stream_text_chunk(" Introduction");
1235    fake_model.end_last_completion_stream();
1236    cx.run_until_parked();
1237
1238    // Summary should be set
1239    text_thread.read_with(cx, |text_thread, _| {
1240        assert_eq!(text_thread.summary().or_default(), "Brief Introduction");
1241    });
1242
1243    // We should be able to manually set a summary
1244    text_thread.update(cx, |text_thread, cx| {
1245        text_thread.set_custom_summary("Brief Intro".into(), cx);
1246    });
1247
1248    text_thread.read_with(cx, |text_thread, _| {
1249        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1250    });
1251}
1252
1253#[gpui::test]
1254async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1255    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1256
1257    test_summarize_error(&fake_model, &text_thread, cx);
1258
1259    // Now we should be able to set a summary
1260    text_thread.update(cx, |text_thread, cx| {
1261        text_thread.set_custom_summary("Brief Intro".into(), cx);
1262    });
1263
1264    text_thread.read_with(cx, |text_thread, _| {
1265        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1266    });
1267}
1268
1269#[gpui::test]
1270async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1271    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1272
1273    test_summarize_error(&fake_model, &text_thread, cx);
1274
1275    // Sending another message should not trigger another summarize request
1276    text_thread.update(cx, |text_thread, cx| {
1277        text_thread.assist(cx);
1278    });
1279
1280    simulate_successful_response(&fake_model, cx);
1281
1282    text_thread.read_with(cx, |text_thread, _| {
1283        // State is still Error, not Generating
1284        assert!(matches!(text_thread.summary(), TextThreadSummary::Error));
1285    });
1286
1287    // But the summarize request can be invoked manually
1288    text_thread.update(cx, |text_thread, cx| {
1289        text_thread.summarize(true, cx);
1290    });
1291
1292    text_thread.read_with(cx, |text_thread, _| {
1293        assert!(!text_thread.summary().content().unwrap().done);
1294    });
1295
1296    cx.run_until_parked();
1297    fake_model.send_last_completion_stream_text_chunk("A successful summary");
1298    fake_model.end_last_completion_stream();
1299    cx.run_until_parked();
1300
1301    text_thread.read_with(cx, |text_thread, _| {
1302        assert_eq!(text_thread.summary().or_default(), "A successful summary");
1303    });
1304}
1305
1306fn test_summarize_error(
1307    model: &Arc<FakeLanguageModel>,
1308    text_thread: &Entity<TextThread>,
1309    cx: &mut TestAppContext,
1310) {
1311    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1312        text_thread.message_anchors[0].clone()
1313    });
1314    text_thread.update(cx, |text_thread, cx| {
1315        text_thread
1316            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1317            .unwrap();
1318    });
1319
1320    // Send a message
1321    text_thread.update(cx, |text_thread, cx| {
1322        text_thread.assist(cx);
1323    });
1324
1325    simulate_successful_response(model, cx);
1326
1327    text_thread.read_with(cx, |text_thread, _| {
1328        assert!(!text_thread.summary().content().unwrap().done);
1329    });
1330
1331    // Simulate summary request ending
1332    cx.run_until_parked();
1333    model.end_last_completion_stream();
1334    cx.run_until_parked();
1335
1336    // State is set to Error and default message
1337    text_thread.read_with(cx, |text_thread, _| {
1338        assert_eq!(*text_thread.summary(), TextThreadSummary::Error);
1339        assert_eq!(
1340            text_thread.summary().or_default(),
1341            TextThreadSummary::DEFAULT
1342        );
1343    });
1344}
1345
1346fn setup_context_editor_with_fake_model(
1347    cx: &mut TestAppContext,
1348) -> (Entity<TextThread>, Arc<FakeLanguageModel>) {
1349    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1350
1351    let fake_provider = Arc::new(FakeLanguageModelProvider::default());
1352    let fake_model = Arc::new(fake_provider.test_model());
1353
1354    cx.update(|cx| {
1355        init_test(cx);
1356        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1357            let configured_model = ConfiguredModel {
1358                provider: fake_provider.clone(),
1359                model: fake_model.clone(),
1360            };
1361            registry.set_default_model(Some(configured_model.clone()), cx);
1362            registry.set_thread_summary_model(Some(configured_model), cx);
1363        })
1364    });
1365
1366    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1367    let context = cx.new(|cx| {
1368        TextThread::local(
1369            registry,
1370            None,
1371            None,
1372            prompt_builder.clone(),
1373            Arc::new(SlashCommandWorkingSet::default()),
1374            cx,
1375        )
1376    });
1377
1378    (context, fake_model)
1379}
1380
1381fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1382    cx.run_until_parked();
1383    fake_model.send_last_completion_stream_text_chunk("Assistant response");
1384    fake_model.end_last_completion_stream();
1385    cx.run_until_parked();
1386}
1387
1388fn messages(context: &Entity<TextThread>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1389    context
1390        .read(cx)
1391        .messages(cx)
1392        .map(|message| (message.id, message.role, message.offset_range))
1393        .collect()
1394}
1395
1396fn messages_cache(
1397    context: &Entity<TextThread>,
1398    cx: &App,
1399) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1400    context
1401        .read(cx)
1402        .messages(cx)
1403        .map(|message| (message.id, message.cache))
1404        .collect()
1405}
1406
1407fn init_test(cx: &mut App) {
1408    let settings_store = SettingsStore::test(cx);
1409    prompt_store::init(cx);
1410    LanguageModelRegistry::test(cx);
1411    cx.set_global(settings_store);
1412}
1413
1414#[derive(Clone)]
1415struct FakeSlashCommand(String);
1416
1417impl SlashCommand for FakeSlashCommand {
1418    fn name(&self) -> String {
1419        self.0.clone()
1420    }
1421
1422    fn description(&self) -> String {
1423        format!("Fake slash command: {}", self.0)
1424    }
1425
1426    fn menu_text(&self) -> String {
1427        format!("Run fake command: {}", self.0)
1428    }
1429
1430    fn complete_argument(
1431        self: Arc<Self>,
1432        _arguments: &[String],
1433        _cancel: Arc<AtomicBool>,
1434        _workspace: Option<WeakEntity<Workspace>>,
1435        _window: &mut Window,
1436        _cx: &mut App,
1437    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1438        Task::ready(Ok(vec![]))
1439    }
1440
1441    fn requires_argument(&self) -> bool {
1442        false
1443    }
1444
1445    fn run(
1446        self: Arc<Self>,
1447        _arguments: &[String],
1448        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1449        _context_buffer: BufferSnapshot,
1450        _workspace: WeakEntity<Workspace>,
1451        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1452        _window: &mut Window,
1453        _cx: &mut App,
1454    ) -> Task<SlashCommandResult> {
1455        Task::ready(Ok(SlashCommandOutput {
1456            text: format!("Executed fake command: {}", self.0),
1457            sections: vec![],
1458            run_commands_in_text: false,
1459        }
1460        .into_event_stream()))
1461    }
1462}