1use crate::{
2 AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
3 ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
4};
5use anyhow::Result;
6use assistant_slash_command::{
7 ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
8 SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult, SlashCommandWorkingSet,
9};
10use assistant_slash_commands::FileSlashCommand;
11use collections::{HashMap, HashSet};
12use fs::FakeFs;
13use futures::{
14 channel::mpsc,
15 stream::{self, StreamExt},
16};
17use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
18use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
19use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
20use parking_lot::Mutex;
21use pretty_assertions::assert_eq;
22use project::Project;
23use prompt_store::PromptBuilder;
24use rand::prelude::*;
25use serde_json::json;
26use settings::SettingsStore;
27use std::{
28 cell::RefCell,
29 env,
30 ops::Range,
31 path::Path,
32 rc::Rc,
33 sync::{Arc, atomic::AtomicBool},
34};
35use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
36use ui::{IconName, Window};
37use unindent::Unindent;
38use util::{
39 RandomCharIter,
40 test::{generate_marked_text, marked_text_ranges},
41};
42use workspace::Workspace;
43
44#[gpui::test]
45fn test_inserting_and_removing_messages(cx: &mut App) {
46 init_test(cx);
47
48 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
49 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
50 let context = cx.new(|cx| {
51 AssistantContext::local(
52 registry,
53 None,
54 None,
55 prompt_builder.clone(),
56 Arc::new(SlashCommandWorkingSet::default()),
57 cx,
58 )
59 });
60 let buffer = context.read(cx).buffer.clone();
61
62 let message_1 = context.read(cx).message_anchors[0].clone();
63 assert_eq!(
64 messages(&context, cx),
65 vec![(message_1.id, Role::User, 0..0)]
66 );
67
68 let message_2 = context.update(cx, |context, cx| {
69 context
70 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
71 .unwrap()
72 });
73 assert_eq!(
74 messages(&context, cx),
75 vec![
76 (message_1.id, Role::User, 0..1),
77 (message_2.id, Role::Assistant, 1..1)
78 ]
79 );
80
81 buffer.update(cx, |buffer, cx| {
82 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
83 });
84 assert_eq!(
85 messages(&context, cx),
86 vec![
87 (message_1.id, Role::User, 0..2),
88 (message_2.id, Role::Assistant, 2..3)
89 ]
90 );
91
92 let message_3 = context.update(cx, |context, cx| {
93 context
94 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
95 .unwrap()
96 });
97 assert_eq!(
98 messages(&context, cx),
99 vec![
100 (message_1.id, Role::User, 0..2),
101 (message_2.id, Role::Assistant, 2..4),
102 (message_3.id, Role::User, 4..4)
103 ]
104 );
105
106 let message_4 = context.update(cx, |context, cx| {
107 context
108 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
109 .unwrap()
110 });
111 assert_eq!(
112 messages(&context, cx),
113 vec![
114 (message_1.id, Role::User, 0..2),
115 (message_2.id, Role::Assistant, 2..4),
116 (message_4.id, Role::User, 4..5),
117 (message_3.id, Role::User, 5..5),
118 ]
119 );
120
121 buffer.update(cx, |buffer, cx| {
122 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
123 });
124 assert_eq!(
125 messages(&context, cx),
126 vec![
127 (message_1.id, Role::User, 0..2),
128 (message_2.id, Role::Assistant, 2..4),
129 (message_4.id, Role::User, 4..6),
130 (message_3.id, Role::User, 6..7),
131 ]
132 );
133
134 // Deleting across message boundaries merges the messages.
135 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
136 assert_eq!(
137 messages(&context, cx),
138 vec![
139 (message_1.id, Role::User, 0..3),
140 (message_3.id, Role::User, 3..4),
141 ]
142 );
143
144 // Undoing the deletion should also undo the merge.
145 buffer.update(cx, |buffer, cx| buffer.undo(cx));
146 assert_eq!(
147 messages(&context, cx),
148 vec![
149 (message_1.id, Role::User, 0..2),
150 (message_2.id, Role::Assistant, 2..4),
151 (message_4.id, Role::User, 4..6),
152 (message_3.id, Role::User, 6..7),
153 ]
154 );
155
156 // Redoing the deletion should also redo the merge.
157 buffer.update(cx, |buffer, cx| buffer.redo(cx));
158 assert_eq!(
159 messages(&context, cx),
160 vec![
161 (message_1.id, Role::User, 0..3),
162 (message_3.id, Role::User, 3..4),
163 ]
164 );
165
166 // Ensure we can still insert after a merged message.
167 let message_5 = context.update(cx, |context, cx| {
168 context
169 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
170 .unwrap()
171 });
172 assert_eq!(
173 messages(&context, cx),
174 vec![
175 (message_1.id, Role::User, 0..3),
176 (message_5.id, Role::System, 3..4),
177 (message_3.id, Role::User, 4..5)
178 ]
179 );
180}
181
182#[gpui::test]
183fn test_message_splitting(cx: &mut App) {
184 init_test(cx);
185
186 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
187
188 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
189 let context = cx.new(|cx| {
190 AssistantContext::local(
191 registry.clone(),
192 None,
193 None,
194 prompt_builder.clone(),
195 Arc::new(SlashCommandWorkingSet::default()),
196 cx,
197 )
198 });
199 let buffer = context.read(cx).buffer.clone();
200
201 let message_1 = context.read(cx).message_anchors[0].clone();
202 assert_eq!(
203 messages(&context, cx),
204 vec![(message_1.id, Role::User, 0..0)]
205 );
206
207 buffer.update(cx, |buffer, cx| {
208 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
209 });
210
211 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
212 let message_2 = message_2.unwrap();
213
214 // We recycle newlines in the middle of a split message
215 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
216 assert_eq!(
217 messages(&context, cx),
218 vec![
219 (message_1.id, Role::User, 0..4),
220 (message_2.id, Role::User, 4..16),
221 ]
222 );
223
224 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
225 let message_3 = message_3.unwrap();
226
227 // We don't recycle newlines at the end of a split message
228 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
229 assert_eq!(
230 messages(&context, cx),
231 vec![
232 (message_1.id, Role::User, 0..4),
233 (message_3.id, Role::User, 4..5),
234 (message_2.id, Role::User, 5..17),
235 ]
236 );
237
238 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
239 let message_4 = message_4.unwrap();
240 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
241 assert_eq!(
242 messages(&context, cx),
243 vec![
244 (message_1.id, Role::User, 0..4),
245 (message_3.id, Role::User, 4..5),
246 (message_2.id, Role::User, 5..9),
247 (message_4.id, Role::User, 9..17),
248 ]
249 );
250
251 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
252 let message_5 = message_5.unwrap();
253 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
254 assert_eq!(
255 messages(&context, cx),
256 vec![
257 (message_1.id, Role::User, 0..4),
258 (message_3.id, Role::User, 4..5),
259 (message_2.id, Role::User, 5..9),
260 (message_4.id, Role::User, 9..10),
261 (message_5.id, Role::User, 10..18),
262 ]
263 );
264
265 let (message_6, message_7) =
266 context.update(cx, |context, cx| context.split_message(14..16, cx));
267 let message_6 = message_6.unwrap();
268 let message_7 = message_7.unwrap();
269 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
270 assert_eq!(
271 messages(&context, cx),
272 vec![
273 (message_1.id, Role::User, 0..4),
274 (message_3.id, Role::User, 4..5),
275 (message_2.id, Role::User, 5..9),
276 (message_4.id, Role::User, 9..10),
277 (message_5.id, Role::User, 10..14),
278 (message_6.id, Role::User, 14..17),
279 (message_7.id, Role::User, 17..19),
280 ]
281 );
282}
283
284#[gpui::test]
285fn test_messages_for_offsets(cx: &mut App) {
286 init_test(cx);
287
288 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
289 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
290 let context = cx.new(|cx| {
291 AssistantContext::local(
292 registry,
293 None,
294 None,
295 prompt_builder.clone(),
296 Arc::new(SlashCommandWorkingSet::default()),
297 cx,
298 )
299 });
300 let buffer = context.read(cx).buffer.clone();
301
302 let message_1 = context.read(cx).message_anchors[0].clone();
303 assert_eq!(
304 messages(&context, cx),
305 vec![(message_1.id, Role::User, 0..0)]
306 );
307
308 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
309 let message_2 = context
310 .update(cx, |context, cx| {
311 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
312 })
313 .unwrap();
314 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
315
316 let message_3 = context
317 .update(cx, |context, cx| {
318 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
319 })
320 .unwrap();
321 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
322
323 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
324 assert_eq!(
325 messages(&context, cx),
326 vec![
327 (message_1.id, Role::User, 0..4),
328 (message_2.id, Role::User, 4..8),
329 (message_3.id, Role::User, 8..11)
330 ]
331 );
332
333 assert_eq!(
334 message_ids_for_offsets(&context, &[0, 4, 9], cx),
335 [message_1.id, message_2.id, message_3.id]
336 );
337 assert_eq!(
338 message_ids_for_offsets(&context, &[0, 1, 11], cx),
339 [message_1.id, message_3.id]
340 );
341
342 let message_4 = context
343 .update(cx, |context, cx| {
344 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
345 })
346 .unwrap();
347 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
348 assert_eq!(
349 messages(&context, cx),
350 vec![
351 (message_1.id, Role::User, 0..4),
352 (message_2.id, Role::User, 4..8),
353 (message_3.id, Role::User, 8..12),
354 (message_4.id, Role::User, 12..12)
355 ]
356 );
357 assert_eq!(
358 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
359 [message_1.id, message_2.id, message_3.id, message_4.id]
360 );
361
362 fn message_ids_for_offsets(
363 context: &Entity<AssistantContext>,
364 offsets: &[usize],
365 cx: &App,
366 ) -> Vec<MessageId> {
367 context
368 .read(cx)
369 .messages_for_offsets(offsets.iter().copied(), cx)
370 .into_iter()
371 .map(|message| message.id)
372 .collect()
373 }
374}
375
376#[gpui::test]
377async fn test_slash_commands(cx: &mut TestAppContext) {
378 cx.update(init_test);
379
380 let fs = FakeFs::new(cx.background_executor.clone());
381
382 fs.insert_tree(
383 "/test",
384 json!({
385 "src": {
386 "lib.rs": "fn one() -> usize { 1 }",
387 "main.rs": "
388 use crate::one;
389 fn main() { one(); }
390 ".unindent(),
391 }
392 }),
393 )
394 .await;
395
396 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
397 slash_command_registry.register_command(FileSlashCommand, false);
398
399 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
400 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
401 let context = cx.new(|cx| {
402 AssistantContext::local(
403 registry.clone(),
404 None,
405 None,
406 prompt_builder.clone(),
407 Arc::new(SlashCommandWorkingSet::default()),
408 cx,
409 )
410 });
411
412 #[derive(Default)]
413 struct ContextRanges {
414 parsed_commands: HashSet<Range<language::Anchor>>,
415 command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
416 output_sections: HashSet<Range<language::Anchor>>,
417 }
418
419 let context_ranges = Rc::new(RefCell::new(ContextRanges::default()));
420 context.update(cx, |_, cx| {
421 cx.subscribe(&context, {
422 let context_ranges = context_ranges.clone();
423 move |context, _, event, _| {
424 let mut context_ranges = context_ranges.borrow_mut();
425 match event {
426 ContextEvent::InvokedSlashCommandChanged { command_id } => {
427 let command = context.invoked_slash_command(command_id).unwrap();
428 context_ranges
429 .command_outputs
430 .insert(*command_id, command.range.clone());
431 }
432 ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
433 for range in removed {
434 context_ranges.parsed_commands.remove(range);
435 }
436 for command in updated {
437 context_ranges
438 .parsed_commands
439 .insert(command.source_range.clone());
440 }
441 }
442 ContextEvent::SlashCommandOutputSectionAdded { section } => {
443 context_ranges.output_sections.insert(section.range.clone());
444 }
445 _ => {}
446 }
447 }
448 })
449 .detach();
450 });
451
452 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
453
454 // Insert a slash command
455 buffer.update(cx, |buffer, cx| {
456 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
457 });
458 assert_text_and_context_ranges(
459 &buffer,
460 &context_ranges,
461 &"
462 «/file src/lib.rs»"
463 .unindent(),
464 cx,
465 );
466
467 // Edit the argument of the slash command.
468 buffer.update(cx, |buffer, cx| {
469 let edit_offset = buffer.text().find("lib.rs").unwrap();
470 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
471 });
472 assert_text_and_context_ranges(
473 &buffer,
474 &context_ranges,
475 &"
476 «/file src/main.rs»"
477 .unindent(),
478 cx,
479 );
480
481 // Edit the name of the slash command, using one that doesn't exist.
482 buffer.update(cx, |buffer, cx| {
483 let edit_offset = buffer.text().find("/file").unwrap();
484 buffer.edit(
485 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
486 None,
487 cx,
488 );
489 });
490 assert_text_and_context_ranges(
491 &buffer,
492 &context_ranges,
493 &"
494 /unknown src/main.rs"
495 .unindent(),
496 cx,
497 );
498
499 // Undoing the insertion of an non-existent slash command resorts the previous one.
500 buffer.update(cx, |buffer, cx| buffer.undo(cx));
501 assert_text_and_context_ranges(
502 &buffer,
503 &context_ranges,
504 &"
505 «/file src/main.rs»"
506 .unindent(),
507 cx,
508 );
509
510 let (command_output_tx, command_output_rx) = mpsc::unbounded();
511 context.update(cx, |context, cx| {
512 let command_source_range = context.parsed_slash_commands[0].source_range.clone();
513 context.insert_command_output(
514 command_source_range,
515 "file",
516 Task::ready(Ok(command_output_rx.boxed())),
517 true,
518 cx,
519 );
520 });
521 assert_text_and_context_ranges(
522 &buffer,
523 &context_ranges,
524 &"
525 ⟦«/file src/main.rs»
526 …⟧
527 "
528 .unindent(),
529 cx,
530 );
531
532 command_output_tx
533 .unbounded_send(Ok(SlashCommandEvent::StartSection {
534 icon: IconName::Ai,
535 label: "src/main.rs".into(),
536 metadata: None,
537 }))
538 .unwrap();
539 command_output_tx
540 .unbounded_send(Ok(SlashCommandEvent::Content("src/main.rs".into())))
541 .unwrap();
542 cx.run_until_parked();
543 assert_text_and_context_ranges(
544 &buffer,
545 &context_ranges,
546 &"
547 ⟦«/file src/main.rs»
548 src/main.rs…⟧
549 "
550 .unindent(),
551 cx,
552 );
553
554 command_output_tx
555 .unbounded_send(Ok(SlashCommandEvent::Content("\nfn main() {}".into())))
556 .unwrap();
557 cx.run_until_parked();
558 assert_text_and_context_ranges(
559 &buffer,
560 &context_ranges,
561 &"
562 ⟦«/file src/main.rs»
563 src/main.rs
564 fn main() {}…⟧
565 "
566 .unindent(),
567 cx,
568 );
569
570 command_output_tx
571 .unbounded_send(Ok(SlashCommandEvent::EndSection))
572 .unwrap();
573 cx.run_until_parked();
574 assert_text_and_context_ranges(
575 &buffer,
576 &context_ranges,
577 &"
578 ⟦«/file src/main.rs»
579 ⟪src/main.rs
580 fn main() {}⟫…⟧
581 "
582 .unindent(),
583 cx,
584 );
585
586 drop(command_output_tx);
587 cx.run_until_parked();
588 assert_text_and_context_ranges(
589 &buffer,
590 &context_ranges,
591 &"
592 ⟦⟪src/main.rs
593 fn main() {}⟫⟧
594 "
595 .unindent(),
596 cx,
597 );
598
599 #[track_caller]
600 fn assert_text_and_context_ranges(
601 buffer: &Entity<Buffer>,
602 ranges: &RefCell<ContextRanges>,
603 expected_marked_text: &str,
604 cx: &mut TestAppContext,
605 ) {
606 let mut actual_marked_text = String::new();
607 buffer.update(cx, |buffer, _| {
608 struct Endpoint {
609 offset: usize,
610 marker: char,
611 }
612
613 let ranges = ranges.borrow();
614 let mut endpoints = Vec::new();
615 for range in ranges.command_outputs.values() {
616 endpoints.push(Endpoint {
617 offset: range.start.to_offset(buffer),
618 marker: '⟦',
619 });
620 }
621 for range in ranges.parsed_commands.iter() {
622 endpoints.push(Endpoint {
623 offset: range.start.to_offset(buffer),
624 marker: '«',
625 });
626 }
627 for range in ranges.output_sections.iter() {
628 endpoints.push(Endpoint {
629 offset: range.start.to_offset(buffer),
630 marker: '⟪',
631 });
632 }
633
634 for range in ranges.output_sections.iter() {
635 endpoints.push(Endpoint {
636 offset: range.end.to_offset(buffer),
637 marker: '⟫',
638 });
639 }
640 for range in ranges.parsed_commands.iter() {
641 endpoints.push(Endpoint {
642 offset: range.end.to_offset(buffer),
643 marker: '»',
644 });
645 }
646 for range in ranges.command_outputs.values() {
647 endpoints.push(Endpoint {
648 offset: range.end.to_offset(buffer),
649 marker: '⟧',
650 });
651 }
652
653 endpoints.sort_by_key(|endpoint| endpoint.offset);
654 let mut offset = 0;
655 for endpoint in endpoints {
656 actual_marked_text.extend(buffer.text_for_range(offset..endpoint.offset));
657 actual_marked_text.push(endpoint.marker);
658 offset = endpoint.offset;
659 }
660 actual_marked_text.extend(buffer.text_for_range(offset..buffer.len()));
661 });
662
663 assert_eq!(actual_marked_text, expected_marked_text);
664 }
665}
666
// Checks that `<patch>` blocks inside an assistant message become entries in
// `context.patches` as the text streams in, are re-parsed after manual edits, disappear when
// the message's role is cycled away from Assistant, come back when it is cycled back, and are
// re-parsed after a serialize/deserialize round trip.
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
    cx.update(|cx| {
        init_test(cx);
        // Turn on the `assistant.enable_experimental_live_diffs` user setting.
        cx.update_global(|settings_store: &mut SettingsStore, cx| {
            settings_store
                .set_user_settings(
                    r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
                    cx,
                )
                .unwrap()
        })
    });
    // An empty fake file system backing a test project rooted at `/root`.
    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [Path::new("/root")], cx).await;

    let registry = Arc::new(LanguageRegistry::test(cx.executor()));

    // Create a new context
    let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
    let context = cx.new(|cx| {
        AssistantContext::local(
            registry.clone(),
            Some(project),
            None,
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            cx,
        )
    });

    // Insert an assistant message to simulate a response.
    let assistant_message_id = context.update(cx, |context, cx| {
        let user_message_id = context.messages(cx).next().unwrap().id;
        context
            .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
            .id
    });

    // In `edit` calls, `«…»` presumably marks the text being inserted or replaced (see
    // `edit_via_marked_text`). In `expect_patches` calls, `«…»` marks the buffer range each
    // parsed patch must cover.

    // No edit tags
    edit(
        &context,
        "

        «one
        two
        »",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two
        ",
        &[],
        cx,
    );

    // Partial edit step tag is added
    edit(
        &context,
        "

        one
        two
        «
        <patch»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        <patch",
        &[],
        cx,
    );

    // The rest of the step tag is added. The unclosed
    // step is treated as incomplete.
    edit(
        &context,
        "

        one
        two

        <patch«>
        <edit>»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>»",
        &[&[]],
        cx,
    );

    // The full patch is added
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>«
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,»",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn one</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn one".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // The step is manually edited.
    edit(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>«fn zero»</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        cx,
    );
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // When setting the message role to User, the steps are cleared.
    // (Two cycles are needed to move from Assistant to User.)
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        <patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>

        also,",
        &[],
        cx,
    );

    // When setting the message role back to Assistant, the steps are reparsed.
    context.update(cx, |context, cx| {
        context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
    });
    expect_patches(
        &context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // Ensure steps are re-parsed when deserializing.
    let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
    let deserialized_context = cx.new(|cx| {
        AssistantContext::deserialize(
            serialized_context,
            Path::new("").into(),
            registry.clone(),
            prompt_builder.clone(),
            Arc::new(SlashCommandWorkingSet::default()),
            None,
            None,
            cx,
        )
    });
    expect_patches(
        &deserialized_context,
        "

        one
        two

        «<patch>
        <edit>
        <description>add a `two` function</description>
        <path>src/lib.rs</path>
        <operation>insert_after</operation>
        <old_text>fn zero</old_text>
        <new_text>
        fn two() {}
        </new_text>
        </edit>
        </patch>
        »
        also,",
        &[&[AssistantEdit {
            path: "src/lib.rs".into(),
            kind: AssistantEditKind::InsertAfter {
                old_text: "fn zero".into(),
                new_text: "fn two() {}".into(),
                description: Some("add a `two` function".into()),
            },
        }]],
        cx,
    );

    // Applies the unindented marked text to the context's buffer via
    // `edit_via_marked_text`. It then runs the executor until parked, presumably so that
    // background patch parsing catches up before the next assertion.
    fn edit(
        context: &Entity<AssistantContext>,
        new_text_marked_with_edits: &str,
        cx: &mut TestAppContext,
    ) {
        context.update(cx, |context, cx| {
            context.buffer.update(cx, |buffer, cx| {
                buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
            });
        });
        cx.executor().run_until_parked();
    }

    // Asserts three things:
    // - the buffer text equals `expected_marked_text` with its markers stripped;
    // - the ranges of `context.patches` render to exactly the `«…»` markers;
    // - each patch's edits (path and kind) equal the matching `expected_suggestions` entry.
    #[track_caller]
    fn expect_patches(
        context: &Entity<AssistantContext>,
        expected_marked_text: &str,
        expected_suggestions: &[&[AssistantEdit]],
        cx: &mut TestAppContext,
    ) {
        let expected_marked_text = expected_marked_text.unindent();
        let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);

        // Snapshot the buffer text, the patch ranges as offsets, and the patches' edits.
        let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
            context.buffer.read_with(cx, |buffer, _| {
                let ranges = context
                    .patches
                    .iter()
                    .map(|entry| entry.range.to_offset(buffer))
                    .collect::<Vec<_>>();
                (
                    buffer.text(),
                    ranges,
                    context
                        .patches
                        .iter()
                        .map(|step| step.edits.clone())
                        .collect::<Vec<_>>(),
                )
            })
        });

        assert_eq!(buffer_text, expected_text);

        // Re-mark the expected text with the actual patch ranges and compare the marked forms.
        let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
        assert_eq!(actual_marked_text, expected_marked_text);

        // Each stored edit is wrapped and is unwrapped here, so an edit that failed to parse
        // fails the test. Only `path` and `kind` are compared.
        assert_eq!(
            patches
                .iter()
                .map(|patch| {
                    patch
                        .iter()
                        .map(|edit| {
                            let edit = edit.as_ref().unwrap();
                            AssistantEdit {
                                path: edit.path.clone(),
                                kind: edit.kind.clone(),
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>(),
            expected_suggestions
        );
    }
}
1061
1062#[gpui::test]
1063async fn test_serialization(cx: &mut TestAppContext) {
1064 cx.update(init_test);
1065
1066 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
1067 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1068 let context = cx.new(|cx| {
1069 AssistantContext::local(
1070 registry.clone(),
1071 None,
1072 None,
1073 prompt_builder.clone(),
1074 Arc::new(SlashCommandWorkingSet::default()),
1075 cx,
1076 )
1077 });
1078 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
1079 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
1080 let message_1 = context.update(cx, |context, cx| {
1081 context
1082 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
1083 .unwrap()
1084 });
1085 let message_2 = context.update(cx, |context, cx| {
1086 context
1087 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
1088 .unwrap()
1089 });
1090 buffer.update(cx, |buffer, cx| {
1091 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
1092 buffer.finalize_last_transaction();
1093 });
1094 let _message_3 = context.update(cx, |context, cx| {
1095 context
1096 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
1097 .unwrap()
1098 });
1099 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1100 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
1101 assert_eq!(
1102 cx.read(|cx| messages(&context, cx)),
1103 [
1104 (message_0, Role::User, 0..2),
1105 (message_1.id, Role::Assistant, 2..6),
1106 (message_2.id, Role::System, 6..6),
1107 ]
1108 );
1109
1110 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
1111 let deserialized_context = cx.new(|cx| {
1112 AssistantContext::deserialize(
1113 serialized_context,
1114 Path::new("").into(),
1115 registry.clone(),
1116 prompt_builder.clone(),
1117 Arc::new(SlashCommandWorkingSet::default()),
1118 None,
1119 None,
1120 cx,
1121 )
1122 });
1123 let deserialized_buffer =
1124 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
1125 assert_eq!(
1126 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
1127 "a\nb\nc\n"
1128 );
1129 assert_eq!(
1130 cx.read(|cx| messages(&deserialized_context, cx)),
1131 [
1132 (message_0, Role::User, 0..2),
1133 (message_1.id, Role::Assistant, 2..6),
1134 (message_2.id, Role::System, 6..6),
1135 ]
1136 );
1137}
1138
1139#[gpui::test(iterations = 100)]
1140async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
1141 cx.update(init_test);
1142
1143 let min_peers = env::var("MIN_PEERS")
1144 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
1145 .unwrap_or(2);
1146 let max_peers = env::var("MAX_PEERS")
1147 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
1148 .unwrap_or(5);
1149 let operations = env::var("OPERATIONS")
1150 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
1151 .unwrap_or(50);
1152
1153 let slash_commands = cx.update(SlashCommandRegistry::default_global);
1154 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
1155 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
1156 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
1157
1158 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
1159 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
1160 let mut contexts = Vec::new();
1161
1162 let num_peers = rng.gen_range(min_peers..=max_peers);
1163 let context_id = ContextId::new();
1164 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1165 for i in 0..num_peers {
1166 let context = cx.new(|cx| {
1167 AssistantContext::new(
1168 context_id.clone(),
1169 i as ReplicaId,
1170 language::Capability::ReadWrite,
1171 registry.clone(),
1172 prompt_builder.clone(),
1173 Arc::new(SlashCommandWorkingSet::default()),
1174 None,
1175 None,
1176 cx,
1177 )
1178 });
1179
1180 cx.update(|cx| {
1181 cx.subscribe(&context, {
1182 let network = network.clone();
1183 move |_, event, _| {
1184 if let ContextEvent::Operation(op) = event {
1185 network
1186 .lock()
1187 .broadcast(i as ReplicaId, vec![op.to_proto()]);
1188 }
1189 }
1190 })
1191 .detach();
1192 });
1193
1194 contexts.push(context);
1195 network.lock().add_peer(i as ReplicaId);
1196 }
1197
1198 let mut mutation_count = operations;
1199
1200 while mutation_count > 0
1201 || !network.lock().is_idle()
1202 || network.lock().contains_disconnected_peers()
1203 {
1204 let context_index = rng.gen_range(0..contexts.len());
1205 let context = &contexts[context_index];
1206
1207 match rng.gen_range(0..100) {
1208 0..=29 if mutation_count > 0 => {
1209 log::info!("Context {}: edit buffer", context_index);
1210 context.update(cx, |context, cx| {
1211 context
1212 .buffer
1213 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1214 });
1215 mutation_count -= 1;
1216 }
1217 30..=44 if mutation_count > 0 => {
1218 context.update(cx, |context, cx| {
1219 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1220 log::info!("Context {}: split message at {:?}", context_index, range);
1221 context.split_message(range, cx);
1222 });
1223 mutation_count -= 1;
1224 }
1225 45..=59 if mutation_count > 0 => {
1226 context.update(cx, |context, cx| {
1227 if let Some(message) = context.messages(cx).choose(&mut rng) {
1228 let role = *[Role::User, Role::Assistant, Role::System]
1229 .choose(&mut rng)
1230 .unwrap();
1231 log::info!(
1232 "Context {}: insert message after {:?} with {:?}",
1233 context_index,
1234 message.id,
1235 role
1236 );
1237 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1238 }
1239 });
1240 mutation_count -= 1;
1241 }
1242 60..=74 if mutation_count > 0 => {
1243 context.update(cx, |context, cx| {
1244 let command_text = "/".to_string()
1245 + slash_commands
1246 .command_names()
1247 .choose(&mut rng)
1248 .unwrap()
1249 .clone()
1250 .as_ref();
1251
1252 let command_range = context.buffer.update(cx, |buffer, cx| {
1253 let offset = buffer.random_byte_range(0, &mut rng).start;
1254 buffer.edit(
1255 [(offset..offset, format!("\n{}\n", command_text))],
1256 None,
1257 cx,
1258 );
1259 offset + 1..offset + 1 + command_text.len()
1260 });
1261
1262 let output_text = RandomCharIter::new(&mut rng)
1263 .filter(|c| *c != '\r')
1264 .take(10)
1265 .collect::<String>();
1266
1267 let mut events = vec![Ok(SlashCommandEvent::StartMessage {
1268 role: Role::User,
1269 merge_same_roles: true,
1270 })];
1271
1272 let num_sections = rng.gen_range(0..=3);
1273 let mut section_start = 0;
1274 for _ in 0..num_sections {
1275 let mut section_end = rng.gen_range(section_start..=output_text.len());
1276 while !output_text.is_char_boundary(section_end) {
1277 section_end += 1;
1278 }
1279 events.push(Ok(SlashCommandEvent::StartSection {
1280 icon: IconName::Ai,
1281 label: "section".into(),
1282 metadata: None,
1283 }));
1284 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1285 text: output_text[section_start..section_end].to_string(),
1286 run_commands_in_text: false,
1287 })));
1288 events.push(Ok(SlashCommandEvent::EndSection));
1289 section_start = section_end;
1290 }
1291
1292 if section_start < output_text.len() {
1293 events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
1294 text: output_text[section_start..].to_string(),
1295 run_commands_in_text: false,
1296 })));
1297 }
1298
1299 log::info!(
1300 "Context {}: insert slash command output at {:?} with {:?} events",
1301 context_index,
1302 command_range,
1303 events.len()
1304 );
1305
1306 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1307 ..context.buffer.read(cx).anchor_after(command_range.end);
1308 context.insert_command_output(
1309 command_range,
1310 "/command",
1311 Task::ready(Ok(stream::iter(events).boxed())),
1312 true,
1313 cx,
1314 );
1315 });
1316 cx.run_until_parked();
1317 mutation_count -= 1;
1318 }
1319 75..=84 if mutation_count > 0 => {
1320 context.update(cx, |context, cx| {
1321 if let Some(message) = context.messages(cx).choose(&mut rng) {
1322 let new_status = match rng.gen_range(0..3) {
1323 0 => MessageStatus::Done,
1324 1 => MessageStatus::Pending,
1325 _ => MessageStatus::Error(SharedString::from("Random error")),
1326 };
1327 log::info!(
1328 "Context {}: update message {:?} status to {:?}",
1329 context_index,
1330 message.id,
1331 new_status
1332 );
1333 context.update_metadata(message.id, cx, |metadata| {
1334 metadata.status = new_status;
1335 });
1336 }
1337 });
1338 mutation_count -= 1;
1339 }
1340 _ => {
1341 let replica_id = context_index as ReplicaId;
1342 if network.lock().is_disconnected(replica_id) {
1343 network.lock().reconnect_peer(replica_id, 0);
1344
1345 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1346 let host_context = &contexts[0].read(cx);
1347 let guest_context = context.read(cx);
1348 (
1349 guest_context.serialize_ops(&host_context.version(cx), cx),
1350 host_context.serialize_ops(&guest_context.version(cx), cx),
1351 )
1352 });
1353 let ops_to_send = ops_to_send.await;
1354 let ops_to_receive = ops_to_receive
1355 .await
1356 .into_iter()
1357 .map(ContextOperation::from_proto)
1358 .collect::<Result<Vec<_>>>()
1359 .unwrap();
1360 log::info!(
1361 "Context {}: reconnecting. Sent {} operations, received {} operations",
1362 context_index,
1363 ops_to_send.len(),
1364 ops_to_receive.len()
1365 );
1366
1367 network.lock().broadcast(replica_id, ops_to_send);
1368 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1369 } else if rng.gen_bool(0.1) && replica_id != 0 {
1370 log::info!("Context {}: disconnecting", context_index);
1371 network.lock().disconnect_peer(replica_id);
1372 } else if network.lock().has_unreceived(replica_id) {
1373 log::info!("Context {}: applying operations", context_index);
1374 let ops = network.lock().receive(replica_id);
1375 let ops = ops
1376 .into_iter()
1377 .map(ContextOperation::from_proto)
1378 .collect::<Result<Vec<_>>>()
1379 .unwrap();
1380 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1381 }
1382 }
1383 }
1384 }
1385
1386 cx.read(|cx| {
1387 let first_context = contexts[0].read(cx);
1388 for context in &contexts[1..] {
1389 let context = context.read(cx);
1390 assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
1391 assert_eq!(
1392 context.buffer.read(cx).text(),
1393 first_context.buffer.read(cx).text(),
1394 "Context {} text != Context 0 text",
1395 context.buffer.read(cx).replica_id()
1396 );
1397 assert_eq!(
1398 context.message_anchors,
1399 first_context.message_anchors,
1400 "Context {} messages != Context 0 messages",
1401 context.buffer.read(cx).replica_id()
1402 );
1403 assert_eq!(
1404 context.messages_metadata,
1405 first_context.messages_metadata,
1406 "Context {} message metadata != Context 0 message metadata",
1407 context.buffer.read(cx).replica_id()
1408 );
1409 assert_eq!(
1410 context.slash_command_output_sections,
1411 first_context.slash_command_output_sections,
1412 "Context {} slash command output sections != Context 0 slash command output sections",
1413 context.buffer.read(cx).replica_id()
1414 );
1415 }
1416 });
1417}
1418
1419#[gpui::test]
1420fn test_mark_cache_anchors(cx: &mut App) {
1421 init_test(cx);
1422
1423 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1424 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1425 let context = cx.new(|cx| {
1426 AssistantContext::local(
1427 registry,
1428 None,
1429 None,
1430 prompt_builder.clone(),
1431 Arc::new(SlashCommandWorkingSet::default()),
1432 cx,
1433 )
1434 });
1435 let buffer = context.read(cx).buffer.clone();
1436
1437 // Create a test cache configuration
1438 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1439 max_cache_anchors: 3,
1440 should_speculate: true,
1441 min_total_token: 10,
1442 });
1443
1444 let message_1 = context.read(cx).message_anchors[0].clone();
1445
1446 context.update(cx, |context, cx| {
1447 context.mark_cache_anchors(cache_configuration, false, cx)
1448 });
1449
1450 assert_eq!(
1451 messages_cache(&context, cx)
1452 .iter()
1453 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1454 .count(),
1455 0,
1456 "Empty messages should not have any cache anchors."
1457 );
1458
1459 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1460 let message_2 = context
1461 .update(cx, |context, cx| {
1462 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1463 })
1464 .unwrap();
1465
1466 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1467 let message_3 = context
1468 .update(cx, |context, cx| {
1469 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1470 })
1471 .unwrap();
1472 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1473
1474 context.update(cx, |context, cx| {
1475 context.mark_cache_anchors(cache_configuration, false, cx)
1476 });
1477 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1478 assert_eq!(
1479 messages_cache(&context, cx)
1480 .iter()
1481 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1482 .count(),
1483 0,
1484 "Messages should not be marked for cache before going over the token minimum."
1485 );
1486 context.update(cx, |context, _| {
1487 context.token_count = Some(20);
1488 });
1489
1490 context.update(cx, |context, cx| {
1491 context.mark_cache_anchors(cache_configuration, true, cx)
1492 });
1493 assert_eq!(
1494 messages_cache(&context, cx)
1495 .iter()
1496 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1497 .collect::<Vec<bool>>(),
1498 vec![true, true, false],
1499 "Last message should not be an anchor on speculative request."
1500 );
1501
1502 context
1503 .update(cx, |context, cx| {
1504 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1505 })
1506 .unwrap();
1507
1508 context.update(cx, |context, cx| {
1509 context.mark_cache_anchors(cache_configuration, false, cx)
1510 });
1511 assert_eq!(
1512 messages_cache(&context, cx)
1513 .iter()
1514 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1515 .collect::<Vec<bool>>(),
1516 vec![false, true, true, false],
1517 "Most recent message should also be cached if not a speculative request."
1518 );
1519 context.update(cx, |context, cx| {
1520 context.update_cache_status_for_completion(cx)
1521 });
1522 assert_eq!(
1523 messages_cache(&context, cx)
1524 .iter()
1525 .map(|(_, cache)| cache
1526 .as_ref()
1527 .map_or(None, |cache| Some(cache.status.clone())))
1528 .collect::<Vec<Option<CacheStatus>>>(),
1529 vec![
1530 Some(CacheStatus::Cached),
1531 Some(CacheStatus::Cached),
1532 Some(CacheStatus::Cached),
1533 None
1534 ],
1535 "All user messages prior to anchor should be marked as cached."
1536 );
1537
1538 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1539 context.update(cx, |context, cx| {
1540 context.mark_cache_anchors(cache_configuration, false, cx)
1541 });
1542 assert_eq!(
1543 messages_cache(&context, cx)
1544 .iter()
1545 .map(|(_, cache)| cache
1546 .as_ref()
1547 .map_or(None, |cache| Some(cache.status.clone())))
1548 .collect::<Vec<Option<CacheStatus>>>(),
1549 vec![
1550 Some(CacheStatus::Cached),
1551 Some(CacheStatus::Cached),
1552 Some(CacheStatus::Pending),
1553 None
1554 ],
1555 "Modifying a message should invalidate it's cache but leave previous messages."
1556 );
1557 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1558 context.update(cx, |context, cx| {
1559 context.mark_cache_anchors(cache_configuration, false, cx)
1560 });
1561 assert_eq!(
1562 messages_cache(&context, cx)
1563 .iter()
1564 .map(|(_, cache)| cache
1565 .as_ref()
1566 .map_or(None, |cache| Some(cache.status.clone())))
1567 .collect::<Vec<Option<CacheStatus>>>(),
1568 vec![
1569 Some(CacheStatus::Pending),
1570 Some(CacheStatus::Pending),
1571 Some(CacheStatus::Pending),
1572 None
1573 ],
1574 "Modifying a message should invalidate all future messages."
1575 );
1576}
1577
1578fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
1579 context
1580 .read(cx)
1581 .messages(cx)
1582 .map(|message| (message.id, message.role, message.offset_range))
1583 .collect()
1584}
1585
1586fn messages_cache(
1587 context: &Entity<AssistantContext>,
1588 cx: &App,
1589) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1590 context
1591 .read(cx)
1592 .messages(cx)
1593 .map(|message| (message.id, message.cache.clone()))
1594 .collect()
1595}
1596
/// Installs the globals the tests in this file rely on.
///
/// Creates a test `SettingsStore`, runs `prompt_store::init`, sets up the test
/// `LanguageModelRegistry`, installs the settings store as a global, and then initializes
/// `language`, `assistant_settings`, and the `Project` settings.
fn init_test(cx: &mut App) {
    // NOTE(review): the store is created before `prompt_store::init` and the registry setup
    // but only installed as a global afterwards; keep this order unless it is confirmed
    // not to matter.
    let settings_store = SettingsStore::test(cx);
    prompt_store::init(cx);
    LanguageModelRegistry::test(cx);
    cx.set_global(settings_store);
    language::init(cx);
    assistant_settings::init(cx);
    Project::init_settings(cx);
}
1606
1607#[derive(Clone)]
1608struct FakeSlashCommand(String);
1609
1610impl SlashCommand for FakeSlashCommand {
1611 fn name(&self) -> String {
1612 self.0.clone()
1613 }
1614
1615 fn description(&self) -> String {
1616 format!("Fake slash command: {}", self.0)
1617 }
1618
1619 fn menu_text(&self) -> String {
1620 format!("Run fake command: {}", self.0)
1621 }
1622
1623 fn complete_argument(
1624 self: Arc<Self>,
1625 _arguments: &[String],
1626 _cancel: Arc<AtomicBool>,
1627 _workspace: Option<WeakEntity<Workspace>>,
1628 _window: &mut Window,
1629 _cx: &mut App,
1630 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1631 Task::ready(Ok(vec![]))
1632 }
1633
1634 fn requires_argument(&self) -> bool {
1635 false
1636 }
1637
1638 fn run(
1639 self: Arc<Self>,
1640 _arguments: &[String],
1641 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1642 _context_buffer: BufferSnapshot,
1643 _workspace: WeakEntity<Workspace>,
1644 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1645 _window: &mut Window,
1646 _cx: &mut App,
1647 ) -> Task<SlashCommandResult> {
1648 Task::ready(Ok(SlashCommandOutput {
1649 text: format!("Executed fake command: {}", self.0),
1650 sections: vec![],
1651 run_commands_in_text: false,
1652 }
1653 .to_event_stream()))
1654 }
1655}