assistant_text_thread_tests.rs

   1use crate::{
   2    CacheStatus, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, TextThread,
   3    TextThreadEvent, TextThreadId, TextThreadOperation, TextThreadSummary,
   4};
   5use anyhow::Result;
   6use assistant_slash_command::{
   7    ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
   8    SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
   9};
  10use assistant_slash_commands::FileSlashCommand;
  11use collections::{HashMap, HashSet};
  12use fs::FakeFs;
  13use futures::{
  14    channel::mpsc,
  15    stream::{self, StreamExt},
  16};
  17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
  18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
  19use language_model::{
  20    ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
  21    fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
  22};
  23use parking_lot::Mutex;
  24use pretty_assertions::assert_eq;
  25use prompt_store::PromptBuilder;
  26use rand::prelude::*;
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{
  30    cell::RefCell,
  31    env,
  32    ops::Range,
  33    path::Path,
  34    rc::Rc,
  35    sync::{Arc, atomic::AtomicBool},
  36};
  37use text::{ReplicaId, ToOffset, network::Network};
  38use ui::{IconName, Window};
  39use unindent::Unindent;
  40use util::RandomCharIter;
  41use workspace::Workspace;
  42
  43#[gpui::test]
  44fn test_inserting_and_removing_messages(cx: &mut App) {
  45    init_test(cx);
  46
  47    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
  48    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
  49    let text_thread = cx.new(|cx| {
  50        TextThread::local(
  51            registry,
  52            None,
  53            None,
  54            prompt_builder.clone(),
  55            Arc::new(SlashCommandWorkingSet::default()),
  56            cx,
  57        )
  58    });
  59    let buffer = text_thread.read(cx).buffer().clone();
  60
  61    let message_1 = text_thread.read(cx).message_anchors[0].clone();
  62    assert_eq!(
  63        messages(&text_thread, cx),
  64        vec![(message_1.id, Role::User, 0..0)]
  65    );
  66
  67    let message_2 = text_thread.update(cx, |context, cx| {
  68        context
  69            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
  70            .unwrap()
  71    });
  72    assert_eq!(
  73        messages(&text_thread, cx),
  74        vec![
  75            (message_1.id, Role::User, 0..1),
  76            (message_2.id, Role::Assistant, 1..1)
  77        ]
  78    );
  79
  80    buffer.update(cx, |buffer, cx| {
  81        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
  82    });
  83    assert_eq!(
  84        messages(&text_thread, cx),
  85        vec![
  86            (message_1.id, Role::User, 0..2),
  87            (message_2.id, Role::Assistant, 2..3)
  88        ]
  89    );
  90
  91    let message_3 = text_thread.update(cx, |context, cx| {
  92        context
  93            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
  94            .unwrap()
  95    });
  96    assert_eq!(
  97        messages(&text_thread, cx),
  98        vec![
  99            (message_1.id, Role::User, 0..2),
 100            (message_2.id, Role::Assistant, 2..4),
 101            (message_3.id, Role::User, 4..4)
 102        ]
 103    );
 104
 105    let message_4 = text_thread.update(cx, |context, cx| {
 106        context
 107            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 108            .unwrap()
 109    });
 110    assert_eq!(
 111        messages(&text_thread, cx),
 112        vec![
 113            (message_1.id, Role::User, 0..2),
 114            (message_2.id, Role::Assistant, 2..4),
 115            (message_4.id, Role::User, 4..5),
 116            (message_3.id, Role::User, 5..5),
 117        ]
 118    );
 119
 120    buffer.update(cx, |buffer, cx| {
 121        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
 122    });
 123    assert_eq!(
 124        messages(&text_thread, cx),
 125        vec![
 126            (message_1.id, Role::User, 0..2),
 127            (message_2.id, Role::Assistant, 2..4),
 128            (message_4.id, Role::User, 4..6),
 129            (message_3.id, Role::User, 6..7),
 130        ]
 131    );
 132
 133    // Deleting across message boundaries merges the messages.
 134    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
 135    assert_eq!(
 136        messages(&text_thread, cx),
 137        vec![
 138            (message_1.id, Role::User, 0..3),
 139            (message_3.id, Role::User, 3..4),
 140        ]
 141    );
 142
 143    // Undoing the deletion should also undo the merge.
 144    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 145    assert_eq!(
 146        messages(&text_thread, cx),
 147        vec![
 148            (message_1.id, Role::User, 0..2),
 149            (message_2.id, Role::Assistant, 2..4),
 150            (message_4.id, Role::User, 4..6),
 151            (message_3.id, Role::User, 6..7),
 152        ]
 153    );
 154
 155    // Redoing the deletion should also redo the merge.
 156    buffer.update(cx, |buffer, cx| buffer.redo(cx));
 157    assert_eq!(
 158        messages(&text_thread, cx),
 159        vec![
 160            (message_1.id, Role::User, 0..3),
 161            (message_3.id, Role::User, 3..4),
 162        ]
 163    );
 164
 165    // Ensure we can still insert after a merged message.
 166    let message_5 = text_thread.update(cx, |context, cx| {
 167        context
 168            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 169            .unwrap()
 170    });
 171    assert_eq!(
 172        messages(&text_thread, cx),
 173        vec![
 174            (message_1.id, Role::User, 0..3),
 175            (message_5.id, Role::System, 3..4),
 176            (message_3.id, Role::User, 4..5)
 177        ]
 178    );
 179}
 180
 181#[gpui::test]
 182fn test_message_splitting(cx: &mut App) {
 183    init_test(cx);
 184
 185    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 186
 187    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 188    let text_thread = cx.new(|cx| {
 189        TextThread::local(
 190            registry.clone(),
 191            None,
 192            None,
 193            prompt_builder.clone(),
 194            Arc::new(SlashCommandWorkingSet::default()),
 195            cx,
 196        )
 197    });
 198    let buffer = text_thread.read(cx).buffer().clone();
 199
 200    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 201    assert_eq!(
 202        messages(&text_thread, cx),
 203        vec![(message_1.id, Role::User, 0..0)]
 204    );
 205
 206    buffer.update(cx, |buffer, cx| {
 207        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
 208    });
 209
 210    let (_, message_2) =
 211        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 212    let message_2 = message_2.unwrap();
 213
 214    // We recycle newlines in the middle of a split message
 215    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
 216    assert_eq!(
 217        messages(&text_thread, cx),
 218        vec![
 219            (message_1.id, Role::User, 0..4),
 220            (message_2.id, Role::User, 4..16),
 221        ]
 222    );
 223
 224    let (_, message_3) =
 225        text_thread.update(cx, |text_thread, cx| text_thread.split_message(3..3, cx));
 226    let message_3 = message_3.unwrap();
 227
 228    // We don't recycle newlines at the end of a split message
 229    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 230    assert_eq!(
 231        messages(&text_thread, cx),
 232        vec![
 233            (message_1.id, Role::User, 0..4),
 234            (message_3.id, Role::User, 4..5),
 235            (message_2.id, Role::User, 5..17),
 236        ]
 237    );
 238
 239    let (_, message_4) =
 240        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 241    let message_4 = message_4.unwrap();
 242    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
 243    assert_eq!(
 244        messages(&text_thread, cx),
 245        vec![
 246            (message_1.id, Role::User, 0..4),
 247            (message_3.id, Role::User, 4..5),
 248            (message_2.id, Role::User, 5..9),
 249            (message_4.id, Role::User, 9..17),
 250        ]
 251    );
 252
 253    let (_, message_5) =
 254        text_thread.update(cx, |text_thread, cx| text_thread.split_message(9..9, cx));
 255    let message_5 = message_5.unwrap();
 256    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
 257    assert_eq!(
 258        messages(&text_thread, cx),
 259        vec![
 260            (message_1.id, Role::User, 0..4),
 261            (message_3.id, Role::User, 4..5),
 262            (message_2.id, Role::User, 5..9),
 263            (message_4.id, Role::User, 9..10),
 264            (message_5.id, Role::User, 10..18),
 265        ]
 266    );
 267
 268    let (message_6, message_7) =
 269        text_thread.update(cx, |text_thread, cx| text_thread.split_message(14..16, cx));
 270    let message_6 = message_6.unwrap();
 271    let message_7 = message_7.unwrap();
 272    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
 273    assert_eq!(
 274        messages(&text_thread, cx),
 275        vec![
 276            (message_1.id, Role::User, 0..4),
 277            (message_3.id, Role::User, 4..5),
 278            (message_2.id, Role::User, 5..9),
 279            (message_4.id, Role::User, 9..10),
 280            (message_5.id, Role::User, 10..14),
 281            (message_6.id, Role::User, 14..17),
 282            (message_7.id, Role::User, 17..19),
 283        ]
 284    );
 285}
 286
 287#[gpui::test]
 288fn test_messages_for_offsets(cx: &mut App) {
 289    init_test(cx);
 290
 291    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 292    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 293    let text_thread = cx.new(|cx| {
 294        TextThread::local(
 295            registry,
 296            None,
 297            None,
 298            prompt_builder.clone(),
 299            Arc::new(SlashCommandWorkingSet::default()),
 300            cx,
 301        )
 302    });
 303    let buffer = text_thread.read(cx).buffer().clone();
 304
 305    let message_1 = text_thread.read(cx).message_anchors[0].clone();
 306    assert_eq!(
 307        messages(&text_thread, cx),
 308        vec![(message_1.id, Role::User, 0..0)]
 309    );
 310
 311    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
 312    let message_2 = text_thread
 313        .update(cx, |text_thread, cx| {
 314            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
 315        })
 316        .unwrap();
 317    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 318
 319    let message_3 = text_thread
 320        .update(cx, |text_thread, cx| {
 321            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
 322        })
 323        .unwrap();
 324    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
 325
 326    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
 327    assert_eq!(
 328        messages(&text_thread, cx),
 329        vec![
 330            (message_1.id, Role::User, 0..4),
 331            (message_2.id, Role::User, 4..8),
 332            (message_3.id, Role::User, 8..11)
 333        ]
 334    );
 335
 336    assert_eq!(
 337        message_ids_for_offsets(&text_thread, &[0, 4, 9], cx),
 338        [message_1.id, message_2.id, message_3.id]
 339    );
 340    assert_eq!(
 341        message_ids_for_offsets(&text_thread, &[0, 1, 11], cx),
 342        [message_1.id, message_3.id]
 343    );
 344
 345    let message_4 = text_thread
 346        .update(cx, |text_thread, cx| {
 347            text_thread.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
 348        })
 349        .unwrap();
 350    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
 351    assert_eq!(
 352        messages(&text_thread, cx),
 353        vec![
 354            (message_1.id, Role::User, 0..4),
 355            (message_2.id, Role::User, 4..8),
 356            (message_3.id, Role::User, 8..12),
 357            (message_4.id, Role::User, 12..12)
 358        ]
 359    );
 360    assert_eq!(
 361        message_ids_for_offsets(&text_thread, &[0, 4, 8, 12], cx),
 362        [message_1.id, message_2.id, message_3.id, message_4.id]
 363    );
 364
 365    fn message_ids_for_offsets(
 366        context: &Entity<TextThread>,
 367        offsets: &[usize],
 368        cx: &App,
 369    ) -> Vec<MessageId> {
 370        context
 371            .read(cx)
 372            .messages_for_offsets(offsets.iter().copied(), cx)
 373            .into_iter()
 374            .map(|message| message.id)
 375            .collect()
 376    }
 377}
 378
 379#[gpui::test]
 380async fn test_slash_commands(cx: &mut TestAppContext) {
 381    cx.update(init_test);
 382
 383    let fs = FakeFs::new(cx.background_executor.clone());
 384
 385    fs.insert_tree(
 386        "/test",
 387        json!({
 388            "src": {
 389                "lib.rs": "fn one() -> usize { 1 }",
 390                "main.rs": "
 391                    use crate::one;
 392                    fn main() { one(); }
 393                ".unindent(),
 394            }
 395        }),
 396    )
 397    .await;
 398
 399    let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
 400    slash_command_registry.register_command(FileSlashCommand, false);
 401
 402    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 403    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 404    let text_thread = cx.new(|cx| {
 405        TextThread::local(
 406            registry.clone(),
 407            None,
 408            None,
 409            prompt_builder.clone(),
 410            Arc::new(SlashCommandWorkingSet::default()),
 411            cx,
 412        )
 413    });
 414
 415    #[derive(Default)]
 416    struct ContextRanges {
 417        parsed_commands: HashSet<Range<language::Anchor>>,
 418        command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
 419        output_sections: HashSet<Range<language::Anchor>>,
 420    }
 421
 422    let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
 423    text_thread.update(cx, |_, cx| {
 424        cx.subscribe(&text_thread, {
 425            let context_ranges = context_ranges.clone();
 426            move |text_thread, _, event, _| {
 427                let mut context_ranges = context_ranges.borrow_mut();
 428                match event {
 429                    TextThreadEvent::InvokedSlashCommandChanged { command_id } => {
 430                        let command = text_thread.invoked_slash_command(command_id).unwrap();
 431                        context_ranges
 432                            .command_outputs
 433                            .insert(*command_id, command.range.clone());
 434                    }
 435                    TextThreadEvent::ParsedSlashCommandsUpdated { removed, updated } => {
 436                        for range in removed {
 437                            context_ranges.parsed_commands.remove(range);
 438                        }
 439                        for command in updated {
 440                            context_ranges
 441                                .parsed_commands
 442                                .insert(command.source_range.clone());
 443                        }
 444                    }
 445                    TextThreadEvent::SlashCommandOutputSectionAdded { section } => {
 446                        context_ranges.output_sections.insert(section.range.clone());
 447                    }
 448                    _ => {}
 449                }
 450            }
 451        })
 452        .detach();
 453    });
 454
 455    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 456
 457    // Insert a slash command
 458    buffer.update(cx, |buffer, cx| {
 459        buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
 460    });
 461    assert_text_and_context_ranges(
 462        &buffer,
 463        &context_ranges,
 464        &"
 465        «/file src/lib.rs»"
 466            .unindent(),
 467        cx,
 468    );
 469
 470    // Edit the argument of the slash command.
 471    buffer.update(cx, |buffer, cx| {
 472        let edit_offset = buffer.text().find("lib.rs").unwrap();
 473        buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
 474    });
 475    assert_text_and_context_ranges(
 476        &buffer,
 477        &context_ranges,
 478        &"
 479        «/file src/main.rs»"
 480            .unindent(),
 481        cx,
 482    );
 483
 484    // Edit the name of the slash command, using one that doesn't exist.
 485    buffer.update(cx, |buffer, cx| {
 486        let edit_offset = buffer.text().find("/file").unwrap();
 487        buffer.edit(
 488            [(edit_offset..edit_offset + "/file".len(), "/unknown")],
 489            None,
 490            cx,
 491        );
 492    });
 493    assert_text_and_context_ranges(
 494        &buffer,
 495        &context_ranges,
 496        &"
 497        /unknown src/main.rs"
 498            .unindent(),
 499        cx,
 500    );
 501
 502    // Undoing the insertion of an non-existent slash command resorts the previous one.
 503    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 504    assert_text_and_context_ranges(
 505        &buffer,
 506        &context_ranges,
 507        &"
 508        «/file src/main.rs»"
 509            .unindent(),
 510        cx,
 511    );
 512
 513    let (command_output_tx, command_output_rx) = mpsc::unbounded();
 514    text_thread.update(cx, |text_thread, cx| {
 515        let command_source_range = text_thread.parsed_slash_commands[0].source_range.clone();
 516        text_thread.insert_command_output(
 517            command_source_range,
 518            "file",
 519            Task::ready(Ok(command_output_rx.boxed())),
 520            true,
 521            cx,
 522        );
 523    });
 524    assert_text_and_context_ranges(
 525        &buffer,
 526        &context_ranges,
 527        &"
 528        ⟦«/file src/main.rs»
 529        …⟧
 530        "
 531        .unindent(),
 532        cx,
 533    );
 534
 535    command_output_tx
 536        .unbounded_send(Ok(SlashCommandEvent::StartSection {
 537            icon: IconName::Ai,
 538            label: "src/main.rs".into(),
 539            metadata: None,
 540        }))
 541        .unwrap();
 542    command_output_tx
 543        .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
 544        .unwrap();
 545    cx.run_until_parked();
 546    assert_text_and_context_ranges(
 547        &buffer,
 548        &context_ranges,
 549        &"
 550        ⟦«/file src/main.rs»
 551        src/main.rs…⟧
 552        "
 553        .unindent(),
 554        cx,
 555    );
 556
 557    command_output_tx
 558        .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
 559        .unwrap();
 560    cx.run_until_parked();
 561    assert_text_and_context_ranges(
 562        &buffer,
 563        &context_ranges,
 564        &"
 565        ⟦«/file src/main.rs»
 566        src/main.rs
 567        fn main() {}…⟧
 568        "
 569        .unindent(),
 570        cx,
 571    );
 572
 573    command_output_tx
 574        .unbounded_send(Ok(SlashCommandEvent::EndSection))
 575        .unwrap();
 576    cx.run_until_parked();
 577    assert_text_and_context_ranges(
 578        &buffer,
 579        &context_ranges,
 580        &"
 581        ⟦«/file src/main.rs»
 582        ⟪src/main.rs
 583        fn main() {}⟫…⟧
 584        "
 585        .unindent(),
 586        cx,
 587    );
 588
 589    drop(command_output_tx);
 590    cx.run_until_parked();
 591    assert_text_and_context_ranges(
 592        &buffer,
 593        &context_ranges,
 594        &"
 595        ⟦⟪src/main.rs
 596        fn main() {}⟫⟧
 597        "
 598        .unindent(),
 599        cx,
 600    );
 601
 602    #[track_caller]
 603    fn assert_text_and_context_ranges(
 604        buffer: &Entity<Buffer>,
 605        ranges: &RefCell<ContextRanges>,
 606        expected_marked_text: &str,
 607        cx: &mut TestAppContext,
 608    ) {
 609        let mut actual_marked_text = String::new();
 610        buffer.update(cx, |buffer, _| {
 611            struct Endpoint {
 612                offset: usize,
 613                marker: char,
 614            }
 615
 616            let ranges = ranges.borrow();
 617            let mut endpoints = Vec::new();
 618            for range in ranges.command_outputs.values() {
 619                endpoints.push(Endpoint {
 620                    offset: range.start.to_offset(buffer),
 621                    marker: '',
 622                });
 623            }
 624            for range in ranges.parsed_commands.iter() {
 625                endpoints.push(Endpoint {
 626                    offset: range.start.to_offset(buffer),
 627                    marker: '«',
 628                });
 629            }
 630            for range in ranges.output_sections.iter() {
 631                endpoints.push(Endpoint {
 632                    offset: range.start.to_offset(buffer),
 633                    marker: '',
 634                });
 635            }
 636
 637            for range in ranges.output_sections.iter() {
 638                endpoints.push(Endpoint {
 639                    offset: range.end.to_offset(buffer),
 640                    marker: '',
 641                });
 642            }
 643            for range in ranges.parsed_commands.iter() {
 644                endpoints.push(Endpoint {
 645                    offset: range.end.to_offset(buffer),
 646                    marker: '»',
 647                });
 648            }
 649            for range in ranges.command_outputs.values() {
 650                endpoints.push(Endpoint {
 651                    offset: range.end.to_offset(buffer),
 652                    marker: '',
 653                });
 654            }
 655
 656            endpoints.sort_by_key(|endpoint| endpoint.offset);
 657            let mut offset = 0;
 658            for endpoint in endpoints {
 659                actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
 660                actual_marked_text.push(endpoint.marker);
 661                offset = endpoint.offset;
 662            }
 663            actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
 664        });
 665
 666        assert_eq!(actual_marked_text, expected_marked_text);
 667    }
 668}
 669
 670#[gpui::test]
 671async fn test_serialization(cx: &mut TestAppContext) {
 672    cx.update(init_test);
 673
 674    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
 675    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 676    let text_thread = cx.new(|cx| {
 677        TextThread::local(
 678            registry.clone(),
 679            None,
 680            None,
 681            prompt_builder.clone(),
 682            Arc::new(SlashCommandWorkingSet::default()),
 683            cx,
 684        )
 685    });
 686    let buffer = text_thread.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 687    let message_0 = text_thread.read_with(cx, |text_thread, _| text_thread.message_anchors[0].id);
 688    let message_1 = text_thread.update(cx, |text_thread, cx| {
 689        text_thread
 690            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
 691            .unwrap()
 692    });
 693    let message_2 = text_thread.update(cx, |text_thread, cx| {
 694        text_thread
 695            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
 696            .unwrap()
 697    });
 698    buffer.update(cx, |buffer, cx| {
 699        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
 700        buffer.finalize_last_transaction();
 701    });
 702    let _message_3 = text_thread.update(cx, |text_thread, cx| {
 703        text_thread
 704            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
 705            .unwrap()
 706    });
 707    buffer.update(cx, |buffer, cx| buffer.undo(cx));
 708    assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
 709    assert_eq!(
 710        cx.read(|cx| messages(&text_thread, cx)),
 711        [
 712            (message_0, Role::User, 0..2),
 713            (message_1.id, Role::Assistant, 2..6),
 714            (message_2.id, Role::System, 6..6),
 715        ]
 716    );
 717
 718    let serialized_context = text_thread.read_with(cx, |text_thread, cx| text_thread.serialize(cx));
 719    let deserialized_context = cx.new(|cx| {
 720        TextThread::deserialize(
 721            serialized_context,
 722            Path::new("").into(),
 723            registry.clone(),
 724            prompt_builder.clone(),
 725            Arc::new(SlashCommandWorkingSet::default()),
 726            None,
 727            None,
 728            cx,
 729        )
 730    });
 731    let deserialized_buffer =
 732        deserialized_context.read_with(cx, |text_thread, _| text_thread.buffer().clone());
 733    assert_eq!(
 734        deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
 735        "a\nb\nc\n"
 736    );
 737    assert_eq!(
 738        cx.read(|cx| messages(&deserialized_context, cx)),
 739        [
 740            (message_0, Role::User, 0..2),
 741            (message_1.id, Role::Assistant, 2..6),
 742            (message_2.id, Role::System, 6..6),
 743        ]
 744    );
 745}
 746
 747#[gpui::test(iterations = 25)]
 748async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
 749    cx.update(init_test);
 750
 751    let min_peers = env::var("MIN_PEERS")
 752        .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
 753        .unwrap_or(2);
 754    let max_peers = env::var("MAX_PEERS")
 755        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 756        .unwrap_or(5);
 757    let operations = env::var("OPERATIONS")
 758        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 759        .unwrap_or(50);
 760
 761    let slash_commands = cx.update(SlashCommandRegistry::default_global);
 762    slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
 763    slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
 764    slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
 765
 766    let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
 767    let network = Arc::new(Mutex::new(Network::new(rng.clone())));
 768    let mut text_threads = Vec::new();
 769
 770    let num_peers = rng.random_range(min_peers..=max_peers);
 771    let context_id = TextThreadId::new();
 772    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
 773    for i in 0..num_peers {
 774        let context = cx.new(|cx| {
 775            TextThread::new(
 776                context_id.clone(),
 777                ReplicaId::new(i as u16),
 778                language::Capability::ReadWrite,
 779                registry.clone(),
 780                prompt_builder.clone(),
 781                Arc::new(SlashCommandWorkingSet::default()),
 782                None,
 783                None,
 784                cx,
 785            )
 786        });
 787
 788        cx.update(|cx| {
 789            cx.subscribe(&context, {
 790                let network = network.clone();
 791                move |_, event, _| {
 792                    if let TextThreadEvent::Operation(op) = event {
 793                        network
 794                            .lock()
 795                            .broadcast(ReplicaId::new(i as u16), vec![op.to_proto()]);
 796                    }
 797                }
 798            })
 799            .detach();
 800        });
 801
 802        text_threads.push(context);
 803        network.lock().add_peer(ReplicaId::new(i as u16));
 804    }
 805
 806    let mut mutation_count = operations;
 807
 808    while mutation_count > 0
 809        || !network.lock().is_idle()
 810        || network.lock().contains_disconnected_peers()
 811    {
 812        let context_index = rng.random_range(0..text_threads.len());
 813        let text_thread = &text_threads[context_index];
 814
 815        match rng.random_range(0..100) {
 816            0..=29 if mutation_count > 0 => {
 817                log::info!("Context {}: edit buffer", context_index);
 818                text_thread.update(cx, |text_thread, cx| {
 819                    text_thread
 820                        .buffer()
 821                        .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
 822                });
 823                mutation_count -= 1;
 824            }
 825            30..=44 if mutation_count > 0 => {
 826                text_thread.update(cx, |text_thread, cx| {
 827                    let range = text_thread.buffer().read(cx).random_byte_range(0, &mut rng);
 828                    log::info!("Context {}: split message at {:?}", context_index, range);
 829                    text_thread.split_message(range, cx);
 830                });
 831                mutation_count -= 1;
 832            }
 833            45..=59 if mutation_count > 0 => {
 834                text_thread.update(cx, |text_thread, cx| {
 835                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 836                        let role = *[Role::User, Role::Assistant, Role::System]
 837                            .choose(&mut rng)
 838                            .unwrap();
 839                        log::info!(
 840                            "Context {}: insert message after {:?} with {:?}",
 841                            context_index,
 842                            message.id,
 843                            role
 844                        );
 845                        text_thread.insert_message_after(message.id, role, MessageStatus::Done, cx);
 846                    }
 847                });
 848                mutation_count -= 1;
 849            }
 850            60..=74 if mutation_count > 0 => {
 851                text_thread.update(cx, |text_thread, cx| {
 852                    let command_text = "/".to_string()
 853                        + slash_commands
 854                            .command_names()
 855                            .choose(&mut rng)
 856                            .unwrap()
 857                            .clone()
 858                            .as_ref();
 859
 860                    let command_range = text_thread.buffer().update(cx, |buffer, cx| {
 861                        let offset = buffer.random_byte_range(0, &mut rng).start;
 862                        buffer.edit(
 863                            [(offset..offset, format!("\n{}\n", command_text))],
 864                            None,
 865                            cx,
 866                        );
 867                        offset + 1..offset + 1 + command_text.len()
 868                    });
 869
 870                    let output_text = RandomCharIter::new(&mut rng)
 871                        .filter(|c| *c != '\r')
 872                        .take(10)
 873                        .collect::<String>();
 874
 875                    let mut events = vec![Ok(SlashCommandEvent::StartMessage {
 876                        role: Role::User,
 877                        merge_same_roles: true,
 878                    })];
 879
 880                    let num_sections = rng.random_range(0..=3);
 881                    let mut section_start = 0;
 882                    for _ in 0..num_sections {
 883                        let mut section_end = rng.random_range(section_start..=output_text.len());
 884                        while !output_text.is_char_boundary(section_end) {
 885                            section_end += 1;
 886                        }
 887                        events.push(Ok(SlashCommandEvent::StartSection {
 888                            icon: IconName::Ai,
 889                            label: "section".into(),
 890                            metadata: None,
 891                        }));
 892                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 893                            text: output_text[section_start..section_end].to_string(),
 894                            run_commands_in_text: false,
 895                        })));
 896                        events.push(Ok(SlashCommandEvent::EndSection));
 897                        section_start = section_end;
 898                    }
 899
 900                    if section_start < output_text.len() {
 901                        events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
 902                            text: output_text[section_start..].to_string(),
 903                            run_commands_in_text: false,
 904                        })));
 905                    }
 906
 907                    log::info!(
 908                        "Context {}: insert slash command output at {:?} with {:?} events",
 909                        context_index,
 910                        command_range,
 911                        events.len()
 912                    );
 913
 914                    let command_range = text_thread
 915                        .buffer()
 916                        .read(cx)
 917                        .anchor_after(command_range.start)
 918                        ..text_thread
 919                            .buffer()
 920                            .read(cx)
 921                            .anchor_after(command_range.end);
 922                    text_thread.insert_command_output(
 923                        command_range,
 924                        "/command",
 925                        Task::ready(Ok(stream::iter(events).boxed())),
 926                        true,
 927                        cx,
 928                    );
 929                });
 930                cx.run_until_parked();
 931                mutation_count -= 1;
 932            }
 933            75..=84 if mutation_count > 0 => {
 934                text_thread.update(cx, |text_thread, cx| {
 935                    if let Some(message) = text_thread.messages(cx).choose(&mut rng) {
 936                        let new_status = match rng.random_range(0..3) {
 937                            0 => MessageStatus::Done,
 938                            1 => MessageStatus::Pending,
 939                            _ => MessageStatus::Error(SharedString::from("Random error")),
 940                        };
 941                        log::info!(
 942                            "Context {}: update message {:?} status to {:?}",
 943                            context_index,
 944                            message.id,
 945                            new_status
 946                        );
 947                        text_thread.update_metadata(message.id, cx, |metadata| {
 948                            metadata.status = new_status;
 949                        });
 950                    }
 951                });
 952                mutation_count -= 1;
 953            }
 954            _ => {
 955                let replica_id = ReplicaId::new(context_index as u16);
 956                if network.lock().is_disconnected(replica_id) {
 957                    network.lock().reconnect_peer(replica_id, ReplicaId::new(0));
 958
 959                    let (ops_to_send, ops_to_receive) = cx.read(|cx| {
 960                        let host_context = &text_threads[0].read(cx);
 961                        let guest_context = text_thread.read(cx);
 962                        (
 963                            guest_context.serialize_ops(&host_context.version(cx), cx),
 964                            host_context.serialize_ops(&guest_context.version(cx), cx),
 965                        )
 966                    });
 967                    let ops_to_send = ops_to_send.await;
 968                    let ops_to_receive = ops_to_receive
 969                        .await
 970                        .into_iter()
 971                        .map(TextThreadOperation::from_proto)
 972                        .collect::<Result<Vec<_>>>()
 973                        .unwrap();
 974                    log::info!(
 975                        "Context {}: reconnecting. Sent {} operations, received {} operations",
 976                        context_index,
 977                        ops_to_send.len(),
 978                        ops_to_receive.len()
 979                    );
 980
 981                    network.lock().broadcast(replica_id, ops_to_send);
 982                    text_thread.update(cx, |text_thread, cx| {
 983                        text_thread.apply_ops(ops_to_receive, cx)
 984                    });
 985                } else if rng.random_bool(0.1) && replica_id != ReplicaId::new(0) {
 986                    log::info!("Context {}: disconnecting", context_index);
 987                    network.lock().disconnect_peer(replica_id);
 988                } else if network.lock().has_unreceived(replica_id) {
 989                    log::info!("Context {}: applying operations", context_index);
 990                    let ops = network.lock().receive(replica_id);
 991                    let ops = ops
 992                        .into_iter()
 993                        .map(TextThreadOperation::from_proto)
 994                        .collect::<Result<Vec<_>>>()
 995                        .unwrap();
 996                    text_thread.update(cx, |text_thread, cx| text_thread.apply_ops(ops, cx));
 997                }
 998            }
 999        }
1000    }
1001
1002    cx.read(|cx| {
1003        let first_context = text_threads[0].read(cx);
1004        for text_thread in &text_threads[1..] {
1005            let text_thread = text_thread.read(cx);
1006            assert!(text_thread.pending_ops.is_empty(), "pending ops: {:?}", text_thread.pending_ops);
1007            assert_eq!(
1008                text_thread.buffer().read(cx).text(),
1009                first_context.buffer().read(cx).text(),
1010                "Context {:?} text != Context 0 text",
1011                text_thread.buffer().read(cx).replica_id()
1012            );
1013            assert_eq!(
1014                text_thread.message_anchors,
1015                first_context.message_anchors,
1016                "Context {:?} messages != Context 0 messages",
1017                text_thread.buffer().read(cx).replica_id()
1018            );
1019            assert_eq!(
1020                text_thread.messages_metadata,
1021                first_context.messages_metadata,
1022                "Context {:?} message metadata != Context 0 message metadata",
1023                text_thread.buffer().read(cx).replica_id()
1024            );
1025            assert_eq!(
1026                text_thread.slash_command_output_sections,
1027                first_context.slash_command_output_sections,
1028                "Context {:?} slash command output sections != Context 0 slash command output sections",
1029                text_thread.buffer().read(cx).replica_id()
1030            );
1031        }
1032    });
1033}
1034
1035#[gpui::test]
1036fn test_mark_cache_anchors(cx: &mut App) {
1037    init_test(cx);
1038
1039    let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1040    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1041    let text_thread = cx.new(|cx| {
1042        TextThread::local(
1043            registry,
1044            None,
1045            None,
1046            prompt_builder.clone(),
1047            Arc::new(SlashCommandWorkingSet::default()),
1048            cx,
1049        )
1050    });
1051    let buffer = text_thread.read(cx).buffer().clone();
1052
1053    // Create a test cache configuration
1054    let cache_configuration = &Some(LanguageModelCacheConfiguration {
1055        max_cache_anchors: 3,
1056        should_speculate: true,
1057        min_total_token: 10,
1058    });
1059
1060    let message_1 = text_thread.read(cx).message_anchors[0].clone();
1061
1062    text_thread.update(cx, |text_thread, cx| {
1063        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1064    });
1065
1066    assert_eq!(
1067        messages_cache(&text_thread, cx)
1068            .iter()
1069            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1070            .count(),
1071        0,
1072        "Empty messages should not have any cache anchors."
1073    );
1074
1075    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1076    let message_2 = text_thread
1077        .update(cx, |text_thread, cx| {
1078            text_thread.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1079        })
1080        .unwrap();
1081
1082    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1083    let message_3 = text_thread
1084        .update(cx, |text_thread, cx| {
1085            text_thread.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1086        })
1087        .unwrap();
1088    buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1089
1090    text_thread.update(cx, |text_thread, cx| {
1091        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1092    });
1093    assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1094    assert_eq!(
1095        messages_cache(&text_thread, cx)
1096            .iter()
1097            .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1098            .count(),
1099        0,
1100        "Messages should not be marked for cache before going over the token minimum."
1101    );
1102    text_thread.update(cx, |text_thread, _| {
1103        text_thread.token_count = Some(20);
1104    });
1105
1106    text_thread.update(cx, |text_thread, cx| {
1107        text_thread.mark_cache_anchors(cache_configuration, true, cx)
1108    });
1109    assert_eq!(
1110        messages_cache(&text_thread, cx)
1111            .iter()
1112            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1113            .collect::<Vec<bool>>(),
1114        vec![true, true, false],
1115        "Last message should not be an anchor on speculative request."
1116    );
1117
1118    text_thread
1119        .update(cx, |text_thread, cx| {
1120            text_thread.insert_message_after(
1121                message_3.id,
1122                Role::Assistant,
1123                MessageStatus::Pending,
1124                cx,
1125            )
1126        })
1127        .unwrap();
1128
1129    text_thread.update(cx, |text_thread, cx| {
1130        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1131    });
1132    assert_eq!(
1133        messages_cache(&text_thread, cx)
1134            .iter()
1135            .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor))
1136            .collect::<Vec<bool>>(),
1137        vec![false, true, true, false],
1138        "Most recent message should also be cached if not a speculative request."
1139    );
1140    text_thread.update(cx, |text_thread, cx| {
1141        text_thread.update_cache_status_for_completion(cx)
1142    });
1143    assert_eq!(
1144        messages_cache(&text_thread, cx)
1145            .iter()
1146            .map(|(_, cache)| cache
1147                .as_ref()
1148                .map_or(None, |cache| Some(cache.status.clone())))
1149            .collect::<Vec<Option<CacheStatus>>>(),
1150        vec![
1151            Some(CacheStatus::Cached),
1152            Some(CacheStatus::Cached),
1153            Some(CacheStatus::Cached),
1154            None
1155        ],
1156        "All user messages prior to anchor should be marked as cached."
1157    );
1158
1159    buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1160    text_thread.update(cx, |text_thread, cx| {
1161        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1162    });
1163    assert_eq!(
1164        messages_cache(&text_thread, cx)
1165            .iter()
1166            .map(|(_, cache)| cache
1167                .as_ref()
1168                .map_or(None, |cache| Some(cache.status.clone())))
1169            .collect::<Vec<Option<CacheStatus>>>(),
1170        vec![
1171            Some(CacheStatus::Cached),
1172            Some(CacheStatus::Cached),
1173            Some(CacheStatus::Pending),
1174            None
1175        ],
1176        "Modifying a message should invalidate it's cache but leave previous messages."
1177    );
1178    buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1179    text_thread.update(cx, |text_thread, cx| {
1180        text_thread.mark_cache_anchors(cache_configuration, false, cx)
1181    });
1182    assert_eq!(
1183        messages_cache(&text_thread, cx)
1184            .iter()
1185            .map(|(_, cache)| cache
1186                .as_ref()
1187                .map_or(None, |cache| Some(cache.status.clone())))
1188            .collect::<Vec<Option<CacheStatus>>>(),
1189        vec![
1190            Some(CacheStatus::Pending),
1191            Some(CacheStatus::Pending),
1192            Some(CacheStatus::Pending),
1193            None
1194        ],
1195        "Modifying a message should invalidate all future messages."
1196    );
1197}
1198
1199#[gpui::test]
1200async fn test_summarization(cx: &mut TestAppContext) {
1201    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1202
1203    // Initial state should be pending
1204    text_thread.read_with(cx, |text_thread, _| {
1205        assert!(matches!(text_thread.summary(), TextThreadSummary::Pending));
1206        assert_eq!(
1207            text_thread.summary().or_default(),
1208            TextThreadSummary::DEFAULT
1209        );
1210    });
1211
1212    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1213        text_thread.message_anchors[0].clone()
1214    });
1215    text_thread.update(cx, |context, cx| {
1216        context
1217            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1218            .unwrap();
1219    });
1220
1221    // Send a message
1222    text_thread.update(cx, |text_thread, cx| {
1223        text_thread.assist(cx);
1224    });
1225
1226    simulate_successful_response(&fake_model, cx);
1227
1228    // Should start generating summary when there are >= 2 messages
1229    text_thread.read_with(cx, |text_thread, _| {
1230        assert!(!text_thread.summary().content().unwrap().done);
1231    });
1232
1233    cx.run_until_parked();
1234    fake_model.send_last_completion_stream_text_chunk("Brief");
1235    fake_model.send_last_completion_stream_text_chunk(" Introduction");
1236    fake_model.end_last_completion_stream();
1237    cx.run_until_parked();
1238
1239    // Summary should be set
1240    text_thread.read_with(cx, |text_thread, _| {
1241        assert_eq!(text_thread.summary().or_default(), "Brief Introduction");
1242    });
1243
1244    // We should be able to manually set a summary
1245    text_thread.update(cx, |text_thread, cx| {
1246        text_thread.set_custom_summary("Brief Intro".into(), cx);
1247    });
1248
1249    text_thread.read_with(cx, |text_thread, _| {
1250        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1251    });
1252}
1253
1254#[gpui::test]
1255async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
1256    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1257
1258    test_summarize_error(&fake_model, &text_thread, cx);
1259
1260    // Now we should be able to set a summary
1261    text_thread.update(cx, |text_thread, cx| {
1262        text_thread.set_custom_summary("Brief Intro".into(), cx);
1263    });
1264
1265    text_thread.read_with(cx, |text_thread, _| {
1266        assert_eq!(text_thread.summary().or_default(), "Brief Intro");
1267    });
1268}
1269
1270#[gpui::test]
1271async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
1272    let (text_thread, fake_model) = setup_context_editor_with_fake_model(cx);
1273
1274    test_summarize_error(&fake_model, &text_thread, cx);
1275
1276    // Sending another message should not trigger another summarize request
1277    text_thread.update(cx, |text_thread, cx| {
1278        text_thread.assist(cx);
1279    });
1280
1281    simulate_successful_response(&fake_model, cx);
1282
1283    text_thread.read_with(cx, |text_thread, _| {
1284        // State is still Error, not Generating
1285        assert!(matches!(text_thread.summary(), TextThreadSummary::Error));
1286    });
1287
1288    // But the summarize request can be invoked manually
1289    text_thread.update(cx, |text_thread, cx| {
1290        text_thread.summarize(true, cx);
1291    });
1292
1293    text_thread.read_with(cx, |text_thread, _| {
1294        assert!(!text_thread.summary().content().unwrap().done);
1295    });
1296
1297    cx.run_until_parked();
1298    fake_model.send_last_completion_stream_text_chunk("A successful summary");
1299    fake_model.end_last_completion_stream();
1300    cx.run_until_parked();
1301
1302    text_thread.read_with(cx, |text_thread, _| {
1303        assert_eq!(text_thread.summary().or_default(), "A successful summary");
1304    });
1305}
1306
1307fn test_summarize_error(
1308    model: &Arc<FakeLanguageModel>,
1309    text_thread: &Entity<TextThread>,
1310    cx: &mut TestAppContext,
1311) {
1312    let message_1 = text_thread.read_with(cx, |text_thread, _cx| {
1313        text_thread.message_anchors[0].clone()
1314    });
1315    text_thread.update(cx, |text_thread, cx| {
1316        text_thread
1317            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
1318            .unwrap();
1319    });
1320
1321    // Send a message
1322    text_thread.update(cx, |text_thread, cx| {
1323        text_thread.assist(cx);
1324    });
1325
1326    simulate_successful_response(model, cx);
1327
1328    text_thread.read_with(cx, |text_thread, _| {
1329        assert!(!text_thread.summary().content().unwrap().done);
1330    });
1331
1332    // Simulate summary request ending
1333    cx.run_until_parked();
1334    model.end_last_completion_stream();
1335    cx.run_until_parked();
1336
1337    // State is set to Error and default message
1338    text_thread.read_with(cx, |text_thread, _| {
1339        assert_eq!(*text_thread.summary(), TextThreadSummary::Error);
1340        assert_eq!(
1341            text_thread.summary().or_default(),
1342            TextThreadSummary::DEFAULT
1343        );
1344    });
1345}
1346
1347fn setup_context_editor_with_fake_model(
1348    cx: &mut TestAppContext,
1349) -> (Entity<TextThread>, Arc<FakeLanguageModel>) {
1350    let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1351
1352    let fake_provider = Arc::new(FakeLanguageModelProvider::default());
1353    let fake_model = Arc::new(fake_provider.test_model());
1354
1355    cx.update(|cx| {
1356        init_test(cx);
1357        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1358            let configured_model = ConfiguredModel {
1359                provider: fake_provider.clone(),
1360                model: fake_model.clone(),
1361            };
1362            registry.set_default_model(Some(configured_model.clone()), cx);
1363            registry.set_thread_summary_model(Some(configured_model), cx);
1364        })
1365    });
1366
1367    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1368    let context = cx.new(|cx| {
1369        TextThread::local(
1370            registry,
1371            None,
1372            None,
1373            prompt_builder.clone(),
1374            Arc::new(SlashCommandWorkingSet::default()),
1375            cx,
1376        )
1377    });
1378
1379    (context, fake_model)
1380}
1381
1382fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
1383    cx.run_until_parked();
1384    fake_model.send_last_completion_stream_text_chunk("Assistant response");
1385    fake_model.end_last_completion_stream();
1386    cx.run_until_parked();
1387}
1388
1389fn messages(context: &Entity<TextThread>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1390    context
1391        .read(cx)
1392        .messages(cx)
1393        .map(|message| (message.id, message.role, message.offset_range))
1394        .collect()
1395}
1396
1397fn messages_cache(
1398    context: &Entity<TextThread>,
1399    cx: &App,
1400) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1401    context
1402        .read(cx)
1403        .messages(cx)
1404        .map(|message| (message.id, message.cache))
1405        .collect()
1406}
1407
/// Shared global setup for the tests in this file.
///
/// The steps, in order:
/// 1. Build a test `SettingsStore`.
/// 2. Run `prompt_store::init`.
/// 3. Set up a test `LanguageModelRegistry`. This file later fetches it via
///    `LanguageModelRegistry::global`.
/// 4. Install the settings store as a global.
///
/// NOTE(review): the settings store is only set as a global after the other
/// initializers run. Keep this ordering unless it is confirmed not to matter.
fn init_test(cx: &mut App) {
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
}
1414
1415#[derive(Clone)]
1416struct FakeSlashCommand(String);
1417
1418impl SlashCommand for FakeSlashCommand {
1419    fn name(&self) -> String {
1420        self.0.clone()
1421    }
1422
1423    fn description(&self) -> String {
1424        format!("Fake slash command: {}", self.0)
1425    }
1426
1427    fn menu_text(&self) -> String {
1428        format!("Run fake command: {}", self.0)
1429    }
1430
1431    fn complete_argument(
1432        self: Arc<Self>,
1433        _arguments: &[String],
1434        _cancel: Arc<AtomicBool>,
1435        _workspace: Option<WeakEntity<Workspace>>,
1436        _window: &mut Window,
1437        _cx: &mut App,
1438    ) -> Task<Result<Vec<ArgumentCompletion>>> {
1439        Task::ready(Ok(vec![]))
1440    }
1441
1442    fn requires_argument(&self) -> bool {
1443        false
1444    }
1445
1446    fn run(
1447        self: Arc<Self>,
1448        _arguments: &[String],
1449        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1450        _context_buffer: BufferSnapshot,
1451        _workspace: WeakEntity<Workspace>,
1452        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1453        _window: &mut Window,
1454        _cx: &mut App,
1455    ) -> Task<SlashCommandResult> {
1456        Task::ready(Ok(SlashCommandOutput {
1457            text: format!("Executed fake command: {}", self.0),
1458            sections: vec![],
1459            run_commands_in_text: false,
1460        }
1461        .into_event_stream()))
1462    }
1463}