1use super::{AssistantEdit, MessageCacheMetadata};
2use crate::{
3 assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
4 Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
5};
6use anyhow::Result;
7use assistant_slash_command::{
8 ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
9 SlashCommandRegistry, SlashCommandResult,
10};
11use collections::HashSet;
12use fs::FakeFs;
13use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
14use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
15use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
16use parking_lot::Mutex;
17use pretty_assertions::assert_eq;
18use project::Project;
19use rand::prelude::*;
20use serde_json::json;
21use settings::SettingsStore;
22use std::{
23 cell::RefCell,
24 env,
25 ops::Range,
26 path::Path,
27 rc::Rc,
28 sync::{atomic::AtomicBool, Arc},
29};
30use text::{network::Network, OffsetRangeExt as _, ReplicaId};
31use ui::{Context as _, WindowContext};
32use unindent::Unindent;
33use util::{
34 test::{generate_marked_text, marked_text_ranges},
35 RandomCharIter,
36};
37use workspace::Workspace;
38
39#[gpui::test]
40fn test_inserting_and_removing_messages(cx: &mut AppContext) {
41 let settings_store = SettingsStore::test(cx);
42 LanguageModelRegistry::test(cx);
43 cx.set_global(settings_store);
44 assistant_panel::init(cx);
45 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
46 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
47 let context =
48 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
49 let buffer = context.read(cx).buffer.clone();
50
51 let message_1 = context.read(cx).message_anchors[0].clone();
52 assert_eq!(
53 messages(&context, cx),
54 vec![(message_1.id, Role::User, 0..0)]
55 );
56
57 let message_2 = context.update(cx, |context, cx| {
58 context
59 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
60 .unwrap()
61 });
62 assert_eq!(
63 messages(&context, cx),
64 vec![
65 (message_1.id, Role::User, 0..1),
66 (message_2.id, Role::Assistant, 1..1)
67 ]
68 );
69
70 buffer.update(cx, |buffer, cx| {
71 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
72 });
73 assert_eq!(
74 messages(&context, cx),
75 vec![
76 (message_1.id, Role::User, 0..2),
77 (message_2.id, Role::Assistant, 2..3)
78 ]
79 );
80
81 let message_3 = context.update(cx, |context, cx| {
82 context
83 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
84 .unwrap()
85 });
86 assert_eq!(
87 messages(&context, cx),
88 vec![
89 (message_1.id, Role::User, 0..2),
90 (message_2.id, Role::Assistant, 2..4),
91 (message_3.id, Role::User, 4..4)
92 ]
93 );
94
95 let message_4 = context.update(cx, |context, cx| {
96 context
97 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
98 .unwrap()
99 });
100 assert_eq!(
101 messages(&context, cx),
102 vec![
103 (message_1.id, Role::User, 0..2),
104 (message_2.id, Role::Assistant, 2..4),
105 (message_4.id, Role::User, 4..5),
106 (message_3.id, Role::User, 5..5),
107 ]
108 );
109
110 buffer.update(cx, |buffer, cx| {
111 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
112 });
113 assert_eq!(
114 messages(&context, cx),
115 vec![
116 (message_1.id, Role::User, 0..2),
117 (message_2.id, Role::Assistant, 2..4),
118 (message_4.id, Role::User, 4..6),
119 (message_3.id, Role::User, 6..7),
120 ]
121 );
122
123 // Deleting across message boundaries merges the messages.
124 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
125 assert_eq!(
126 messages(&context, cx),
127 vec![
128 (message_1.id, Role::User, 0..3),
129 (message_3.id, Role::User, 3..4),
130 ]
131 );
132
133 // Undoing the deletion should also undo the merge.
134 buffer.update(cx, |buffer, cx| buffer.undo(cx));
135 assert_eq!(
136 messages(&context, cx),
137 vec![
138 (message_1.id, Role::User, 0..2),
139 (message_2.id, Role::Assistant, 2..4),
140 (message_4.id, Role::User, 4..6),
141 (message_3.id, Role::User, 6..7),
142 ]
143 );
144
145 // Redoing the deletion should also redo the merge.
146 buffer.update(cx, |buffer, cx| buffer.redo(cx));
147 assert_eq!(
148 messages(&context, cx),
149 vec![
150 (message_1.id, Role::User, 0..3),
151 (message_3.id, Role::User, 3..4),
152 ]
153 );
154
155 // Ensure we can still insert after a merged message.
156 let message_5 = context.update(cx, |context, cx| {
157 context
158 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
159 .unwrap()
160 });
161 assert_eq!(
162 messages(&context, cx),
163 vec![
164 (message_1.id, Role::User, 0..3),
165 (message_5.id, Role::System, 3..4),
166 (message_3.id, Role::User, 4..5)
167 ]
168 );
169}
170
171#[gpui::test]
172fn test_message_splitting(cx: &mut AppContext) {
173 let settings_store = SettingsStore::test(cx);
174 cx.set_global(settings_store);
175 LanguageModelRegistry::test(cx);
176 assistant_panel::init(cx);
177 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
178
179 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
180 let context =
181 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
182 let buffer = context.read(cx).buffer.clone();
183
184 let message_1 = context.read(cx).message_anchors[0].clone();
185 assert_eq!(
186 messages(&context, cx),
187 vec![(message_1.id, Role::User, 0..0)]
188 );
189
190 buffer.update(cx, |buffer, cx| {
191 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
192 });
193
194 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
195 let message_2 = message_2.unwrap();
196
197 // We recycle newlines in the middle of a split message
198 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
199 assert_eq!(
200 messages(&context, cx),
201 vec![
202 (message_1.id, Role::User, 0..4),
203 (message_2.id, Role::User, 4..16),
204 ]
205 );
206
207 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
208 let message_3 = message_3.unwrap();
209
210 // We don't recycle newlines at the end of a split message
211 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
212 assert_eq!(
213 messages(&context, cx),
214 vec![
215 (message_1.id, Role::User, 0..4),
216 (message_3.id, Role::User, 4..5),
217 (message_2.id, Role::User, 5..17),
218 ]
219 );
220
221 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
222 let message_4 = message_4.unwrap();
223 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
224 assert_eq!(
225 messages(&context, cx),
226 vec![
227 (message_1.id, Role::User, 0..4),
228 (message_3.id, Role::User, 4..5),
229 (message_2.id, Role::User, 5..9),
230 (message_4.id, Role::User, 9..17),
231 ]
232 );
233
234 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
235 let message_5 = message_5.unwrap();
236 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
237 assert_eq!(
238 messages(&context, cx),
239 vec![
240 (message_1.id, Role::User, 0..4),
241 (message_3.id, Role::User, 4..5),
242 (message_2.id, Role::User, 5..9),
243 (message_4.id, Role::User, 9..10),
244 (message_5.id, Role::User, 10..18),
245 ]
246 );
247
248 let (message_6, message_7) =
249 context.update(cx, |context, cx| context.split_message(14..16, cx));
250 let message_6 = message_6.unwrap();
251 let message_7 = message_7.unwrap();
252 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
253 assert_eq!(
254 messages(&context, cx),
255 vec![
256 (message_1.id, Role::User, 0..4),
257 (message_3.id, Role::User, 4..5),
258 (message_2.id, Role::User, 5..9),
259 (message_4.id, Role::User, 9..10),
260 (message_5.id, Role::User, 10..14),
261 (message_6.id, Role::User, 14..17),
262 (message_7.id, Role::User, 17..19),
263 ]
264 );
265}
266
267#[gpui::test]
268fn test_messages_for_offsets(cx: &mut AppContext) {
269 let settings_store = SettingsStore::test(cx);
270 LanguageModelRegistry::test(cx);
271 cx.set_global(settings_store);
272 assistant_panel::init(cx);
273 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
274 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
275 let context =
276 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
277 let buffer = context.read(cx).buffer.clone();
278
279 let message_1 = context.read(cx).message_anchors[0].clone();
280 assert_eq!(
281 messages(&context, cx),
282 vec![(message_1.id, Role::User, 0..0)]
283 );
284
285 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
286 let message_2 = context
287 .update(cx, |context, cx| {
288 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
289 })
290 .unwrap();
291 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
292
293 let message_3 = context
294 .update(cx, |context, cx| {
295 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
296 })
297 .unwrap();
298 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
299
300 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
301 assert_eq!(
302 messages(&context, cx),
303 vec![
304 (message_1.id, Role::User, 0..4),
305 (message_2.id, Role::User, 4..8),
306 (message_3.id, Role::User, 8..11)
307 ]
308 );
309
310 assert_eq!(
311 message_ids_for_offsets(&context, &[0, 4, 9], cx),
312 [message_1.id, message_2.id, message_3.id]
313 );
314 assert_eq!(
315 message_ids_for_offsets(&context, &[0, 1, 11], cx),
316 [message_1.id, message_3.id]
317 );
318
319 let message_4 = context
320 .update(cx, |context, cx| {
321 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
322 })
323 .unwrap();
324 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
325 assert_eq!(
326 messages(&context, cx),
327 vec![
328 (message_1.id, Role::User, 0..4),
329 (message_2.id, Role::User, 4..8),
330 (message_3.id, Role::User, 8..12),
331 (message_4.id, Role::User, 12..12)
332 ]
333 );
334 assert_eq!(
335 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
336 [message_1.id, message_2.id, message_3.id, message_4.id]
337 );
338
339 fn message_ids_for_offsets(
340 context: &Model<Context>,
341 offsets: &[usize],
342 cx: &AppContext,
343 ) -> Vec<MessageId> {
344 context
345 .read(cx)
346 .messages_for_offsets(offsets.iter().copied(), cx)
347 .into_iter()
348 .map(|message| message.id)
349 .collect()
350 }
351}
352
353#[gpui::test]
354async fn test_slash_commands(cx: &mut TestAppContext) {
355 let settings_store = cx.update(SettingsStore::test);
356 cx.set_global(settings_store);
357 cx.update(LanguageModelRegistry::test);
358 cx.update(Project::init_settings);
359 cx.update(assistant_panel::init);
360 let fs = FakeFs::new(cx.background_executor.clone());
361
362 fs.insert_tree(
363 "/test",
364 json!({
365 "src": {
366 "lib.rs": "fn one() -> usize { 1 }",
367 "main.rs": "
368 use crate::one;
369 fn main() { one(); }
370 ".unindent(),
371 }
372 }),
373 )
374 .await;
375
376 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
377 slash_command_registry.register_command(file_command::FileSlashCommand, false);
378
379 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
380 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
381 let context =
382 cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
383
384 let output_ranges = Rc::new(RefCell::new(HashSet::default()));
385 context.update(cx, |_, cx| {
386 cx.subscribe(&context, {
387 let ranges = output_ranges.clone();
388 move |_, _, event, _| match event {
389 ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
390 for range in removed {
391 ranges.borrow_mut().remove(range);
392 }
393 for command in updated {
394 ranges.borrow_mut().insert(command.source_range.clone());
395 }
396 }
397 _ => {}
398 }
399 })
400 .detach();
401 });
402
403 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
404
405 // Insert a slash command
406 buffer.update(cx, |buffer, cx| {
407 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
408 });
409 assert_text_and_output_ranges(
410 &buffer,
411 &output_ranges.borrow(),
412 "
413 «/file src/lib.rs»
414 "
415 .unindent()
416 .trim_end(),
417 cx,
418 );
419
420 // Edit the argument of the slash command.
421 buffer.update(cx, |buffer, cx| {
422 let edit_offset = buffer.text().find("lib.rs").unwrap();
423 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
424 });
425 assert_text_and_output_ranges(
426 &buffer,
427 &output_ranges.borrow(),
428 "
429 «/file src/main.rs»
430 "
431 .unindent()
432 .trim_end(),
433 cx,
434 );
435
436 // Edit the name of the slash command, using one that doesn't exist.
437 buffer.update(cx, |buffer, cx| {
438 let edit_offset = buffer.text().find("/file").unwrap();
439 buffer.edit(
440 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
441 None,
442 cx,
443 );
444 });
445 assert_text_and_output_ranges(
446 &buffer,
447 &output_ranges.borrow(),
448 "
449 /unknown src/main.rs
450 "
451 .unindent()
452 .trim_end(),
453 cx,
454 );
455
456 #[track_caller]
457 fn assert_text_and_output_ranges(
458 buffer: &Model<Buffer>,
459 ranges: &HashSet<Range<language::Anchor>>,
460 expected_marked_text: &str,
461 cx: &mut TestAppContext,
462 ) {
463 let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
464 let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
465 let mut ranges = ranges
466 .iter()
467 .map(|range| range.to_offset(buffer))
468 .collect::<Vec<_>>();
469 ranges.sort_by_key(|a| a.start);
470 (buffer.text(), ranges)
471 });
472
473 assert_eq!(actual_text, expected_text);
474 assert_eq!(actual_ranges, expected_ranges);
475 }
476}
477
478#[gpui::test]
479async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
480 cx.update(prompt_library::init);
481 let mut settings_store = cx.update(SettingsStore::test);
482 cx.update(|cx| {
483 settings_store
484 .set_user_settings(
485 r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
486 cx,
487 )
488 .unwrap()
489 });
490 cx.set_global(settings_store);
491 cx.update(language::init);
492 cx.update(Project::init_settings);
493 let fs = FakeFs::new(cx.executor());
494 let project = Project::test(fs, [Path::new("/root")], cx).await;
495 cx.update(LanguageModelRegistry::test);
496
497 cx.update(assistant_panel::init);
498 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
499
500 // Create a new context
501 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
502 let context = cx.new_model(|cx| {
503 Context::local(
504 registry.clone(),
505 Some(project),
506 None,
507 prompt_builder.clone(),
508 cx,
509 )
510 });
511
512 // Insert an assistant message to simulate a response.
513 let assistant_message_id = context.update(cx, |context, cx| {
514 let user_message_id = context.messages(cx).next().unwrap().id;
515 context
516 .insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
517 .unwrap()
518 .id
519 });
520
521 // No edit tags
522 edit(
523 &context,
524 "
525
526 «one
527 two
528 »",
529 cx,
530 );
531 expect_patches(
532 &context,
533 "
534
535 one
536 two
537 ",
538 &[],
539 cx,
540 );
541
542 // Partial edit step tag is added
543 edit(
544 &context,
545 "
546
547 one
548 two
549 «
550 <patch»",
551 cx,
552 );
553 expect_patches(
554 &context,
555 "
556
557 one
558 two
559
560 <patch",
561 &[],
562 cx,
563 );
564
565 // The rest of the step tag is added. The unclosed
566 // step is treated as incomplete.
567 edit(
568 &context,
569 "
570
571 one
572 two
573
574 <patch«>
575 <edit>»",
576 cx,
577 );
578 expect_patches(
579 &context,
580 "
581
582 one
583 two
584
585 «<patch>
586 <edit>»",
587 &[&[]],
588 cx,
589 );
590
591 // The full patch is added
592 edit(
593 &context,
594 "
595
596 one
597 two
598
599 <patch>
600 <edit>«
601 <description>add a `two` function</description>
602 <path>src/lib.rs</path>
603 <operation>insert_after</operation>
604 <old_text>fn one</old_text>
605 <new_text>
606 fn two() {}
607 </new_text>
608 </edit>
609 </patch>
610
611 also,»",
612 cx,
613 );
614 expect_patches(
615 &context,
616 "
617
618 one
619 two
620
621 «<patch>
622 <edit>
623 <description>add a `two` function</description>
624 <path>src/lib.rs</path>
625 <operation>insert_after</operation>
626 <old_text>fn one</old_text>
627 <new_text>
628 fn two() {}
629 </new_text>
630 </edit>
631 </patch>
632 »
633 also,",
634 &[&[AssistantEdit {
635 path: "src/lib.rs".into(),
636 kind: AssistantEditKind::InsertAfter {
637 old_text: "fn one".into(),
638 new_text: "fn two() {}".into(),
639 description: Some("add a `two` function".into()),
640 },
641 }]],
642 cx,
643 );
644
645 // The step is manually edited.
646 edit(
647 &context,
648 "
649
650 one
651 two
652
653 <patch>
654 <edit>
655 <description>add a `two` function</description>
656 <path>src/lib.rs</path>
657 <operation>insert_after</operation>
658 <old_text>«fn zero»</old_text>
659 <new_text>
660 fn two() {}
661 </new_text>
662 </edit>
663 </patch>
664
665 also,",
666 cx,
667 );
668 expect_patches(
669 &context,
670 "
671
672 one
673 two
674
675 «<patch>
676 <edit>
677 <description>add a `two` function</description>
678 <path>src/lib.rs</path>
679 <operation>insert_after</operation>
680 <old_text>fn zero</old_text>
681 <new_text>
682 fn two() {}
683 </new_text>
684 </edit>
685 </patch>
686 »
687 also,",
688 &[&[AssistantEdit {
689 path: "src/lib.rs".into(),
690 kind: AssistantEditKind::InsertAfter {
691 old_text: "fn zero".into(),
692 new_text: "fn two() {}".into(),
693 description: Some("add a `two` function".into()),
694 },
695 }]],
696 cx,
697 );
698
699 // When setting the message role to User, the steps are cleared.
700 context.update(cx, |context, cx| {
701 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
702 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
703 });
704 expect_patches(
705 &context,
706 "
707
708 one
709 two
710
711 <patch>
712 <edit>
713 <description>add a `two` function</description>
714 <path>src/lib.rs</path>
715 <operation>insert_after</operation>
716 <old_text>fn zero</old_text>
717 <new_text>
718 fn two() {}
719 </new_text>
720 </edit>
721 </patch>
722
723 also,",
724 &[],
725 cx,
726 );
727
728 // When setting the message role back to Assistant, the steps are reparsed.
729 context.update(cx, |context, cx| {
730 context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
731 });
732 expect_patches(
733 &context,
734 "
735
736 one
737 two
738
739 «<patch>
740 <edit>
741 <description>add a `two` function</description>
742 <path>src/lib.rs</path>
743 <operation>insert_after</operation>
744 <old_text>fn zero</old_text>
745 <new_text>
746 fn two() {}
747 </new_text>
748 </edit>
749 </patch>
750 »
751 also,",
752 &[&[AssistantEdit {
753 path: "src/lib.rs".into(),
754 kind: AssistantEditKind::InsertAfter {
755 old_text: "fn zero".into(),
756 new_text: "fn two() {}".into(),
757 description: Some("add a `two` function".into()),
758 },
759 }]],
760 cx,
761 );
762
763 // Ensure steps are re-parsed when deserializing.
764 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
765 let deserialized_context = cx.new_model(|cx| {
766 Context::deserialize(
767 serialized_context,
768 Default::default(),
769 registry.clone(),
770 prompt_builder.clone(),
771 None,
772 None,
773 cx,
774 )
775 });
776 expect_patches(
777 &deserialized_context,
778 "
779
780 one
781 two
782
783 «<patch>
784 <edit>
785 <description>add a `two` function</description>
786 <path>src/lib.rs</path>
787 <operation>insert_after</operation>
788 <old_text>fn zero</old_text>
789 <new_text>
790 fn two() {}
791 </new_text>
792 </edit>
793 </patch>
794 »
795 also,",
796 &[&[AssistantEdit {
797 path: "src/lib.rs".into(),
798 kind: AssistantEditKind::InsertAfter {
799 old_text: "fn zero".into(),
800 new_text: "fn two() {}".into(),
801 description: Some("add a `two` function".into()),
802 },
803 }]],
804 cx,
805 );
806
807 fn edit(context: &Model<Context>, new_text_marked_with_edits: &str, cx: &mut TestAppContext) {
808 context.update(cx, |context, cx| {
809 context.buffer.update(cx, |buffer, cx| {
810 buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
811 });
812 });
813 cx.executor().run_until_parked();
814 }
815
816 #[track_caller]
817 fn expect_patches(
818 context: &Model<Context>,
819 expected_marked_text: &str,
820 expected_suggestions: &[&[AssistantEdit]],
821 cx: &mut TestAppContext,
822 ) {
823 let expected_marked_text = expected_marked_text.unindent();
824 let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
825
826 let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
827 context.buffer.read_with(cx, |buffer, _| {
828 let ranges = context
829 .patches
830 .iter()
831 .map(|entry| entry.range.to_offset(buffer))
832 .collect::<Vec<_>>();
833 (
834 buffer.text(),
835 ranges,
836 context
837 .patches
838 .iter()
839 .map(|step| step.edits.clone())
840 .collect::<Vec<_>>(),
841 )
842 })
843 });
844
845 assert_eq!(buffer_text, expected_text);
846
847 let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
848 assert_eq!(actual_marked_text, expected_marked_text);
849
850 assert_eq!(
851 patches
852 .iter()
853 .map(|patch| {
854 patch
855 .iter()
856 .map(|edit| {
857 let edit = edit.as_ref().unwrap();
858 AssistantEdit {
859 path: edit.path.clone(),
860 kind: edit.kind.clone(),
861 }
862 })
863 .collect::<Vec<_>>()
864 })
865 .collect::<Vec<_>>(),
866 expected_suggestions
867 );
868 }
869}
870
871#[gpui::test]
872async fn test_serialization(cx: &mut TestAppContext) {
873 let settings_store = cx.update(SettingsStore::test);
874 cx.set_global(settings_store);
875 cx.update(LanguageModelRegistry::test);
876 cx.update(assistant_panel::init);
877 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
878 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
879 let context =
880 cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
881 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
882 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
883 let message_1 = context.update(cx, |context, cx| {
884 context
885 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
886 .unwrap()
887 });
888 let message_2 = context.update(cx, |context, cx| {
889 context
890 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
891 .unwrap()
892 });
893 buffer.update(cx, |buffer, cx| {
894 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
895 buffer.finalize_last_transaction();
896 });
897 let _message_3 = context.update(cx, |context, cx| {
898 context
899 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
900 .unwrap()
901 });
902 buffer.update(cx, |buffer, cx| buffer.undo(cx));
903 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
904 assert_eq!(
905 cx.read(|cx| messages(&context, cx)),
906 [
907 (message_0, Role::User, 0..2),
908 (message_1.id, Role::Assistant, 2..6),
909 (message_2.id, Role::System, 6..6),
910 ]
911 );
912
913 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
914 let deserialized_context = cx.new_model(|cx| {
915 Context::deserialize(
916 serialized_context,
917 Default::default(),
918 registry.clone(),
919 prompt_builder.clone(),
920 None,
921 None,
922 cx,
923 )
924 });
925 let deserialized_buffer =
926 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
927 assert_eq!(
928 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
929 "a\nb\nc\n"
930 );
931 assert_eq!(
932 cx.read(|cx| messages(&deserialized_context, cx)),
933 [
934 (message_0, Role::User, 0..2),
935 (message_1.id, Role::Assistant, 2..6),
936 (message_2.id, Role::System, 6..6),
937 ]
938 );
939}
940
941#[gpui::test(iterations = 100)]
942async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
943 let min_peers = env::var("MIN_PEERS")
944 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
945 .unwrap_or(2);
946 let max_peers = env::var("MAX_PEERS")
947 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
948 .unwrap_or(5);
949 let operations = env::var("OPERATIONS")
950 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
951 .unwrap_or(50);
952
953 let settings_store = cx.update(SettingsStore::test);
954 cx.set_global(settings_store);
955 cx.update(LanguageModelRegistry::test);
956
957 cx.update(assistant_panel::init);
958 let slash_commands = cx.update(SlashCommandRegistry::default_global);
959 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
960 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
961 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
962
963 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
964 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
965 let mut contexts = Vec::new();
966
967 let num_peers = rng.gen_range(min_peers..=max_peers);
968 let context_id = ContextId::new();
969 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
970 for i in 0..num_peers {
971 let context = cx.new_model(|cx| {
972 Context::new(
973 context_id.clone(),
974 i as ReplicaId,
975 language::Capability::ReadWrite,
976 registry.clone(),
977 prompt_builder.clone(),
978 None,
979 None,
980 cx,
981 )
982 });
983
984 cx.update(|cx| {
985 cx.subscribe(&context, {
986 let network = network.clone();
987 move |_, event, _| {
988 if let ContextEvent::Operation(op) = event {
989 network
990 .lock()
991 .broadcast(i as ReplicaId, vec![op.to_proto()]);
992 }
993 }
994 })
995 .detach();
996 });
997
998 contexts.push(context);
999 network.lock().add_peer(i as ReplicaId);
1000 }
1001
1002 let mut mutation_count = operations;
1003
1004 while mutation_count > 0
1005 || !network.lock().is_idle()
1006 || network.lock().contains_disconnected_peers()
1007 {
1008 let context_index = rng.gen_range(0..contexts.len());
1009 let context = &contexts[context_index];
1010
1011 match rng.gen_range(0..100) {
1012 0..=29 if mutation_count > 0 => {
1013 log::info!("Context {}: edit buffer", context_index);
1014 context.update(cx, |context, cx| {
1015 context
1016 .buffer
1017 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
1018 });
1019 mutation_count -= 1;
1020 }
1021 30..=44 if mutation_count > 0 => {
1022 context.update(cx, |context, cx| {
1023 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
1024 log::info!("Context {}: split message at {:?}", context_index, range);
1025 context.split_message(range, cx);
1026 });
1027 mutation_count -= 1;
1028 }
1029 45..=59 if mutation_count > 0 => {
1030 context.update(cx, |context, cx| {
1031 if let Some(message) = context.messages(cx).choose(&mut rng) {
1032 let role = *[Role::User, Role::Assistant, Role::System]
1033 .choose(&mut rng)
1034 .unwrap();
1035 log::info!(
1036 "Context {}: insert message after {:?} with {:?}",
1037 context_index,
1038 message.id,
1039 role
1040 );
1041 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
1042 }
1043 });
1044 mutation_count -= 1;
1045 }
1046 60..=74 if mutation_count > 0 => {
1047 context.update(cx, |context, cx| {
1048 let command_text = "/".to_string()
1049 + slash_commands
1050 .command_names()
1051 .choose(&mut rng)
1052 .unwrap()
1053 .clone()
1054 .as_ref();
1055
1056 let command_range = context.buffer.update(cx, |buffer, cx| {
1057 let offset = buffer.random_byte_range(0, &mut rng).start;
1058 buffer.edit(
1059 [(offset..offset, format!("\n{}\n", command_text))],
1060 None,
1061 cx,
1062 );
1063 offset + 1..offset + 1 + command_text.len()
1064 });
1065
1066 let output_len = rng.gen_range(1..=10);
1067 let output_text = RandomCharIter::new(&mut rng)
1068 .filter(|c| *c != '\r')
1069 .take(output_len)
1070 .collect::<String>();
1071
1072 let num_sections = rng.gen_range(0..=3);
1073 let mut sections = Vec::with_capacity(num_sections);
1074 for _ in 0..num_sections {
1075 let section_start = rng.gen_range(0..output_len);
1076 let section_end = rng.gen_range(section_start..=output_len);
1077 sections.push(SlashCommandOutputSection {
1078 range: section_start..section_end,
1079 icon: ui::IconName::Ai,
1080 label: "section".into(),
1081 metadata: None,
1082 });
1083 }
1084
1085 log::info!(
1086 "Context {}: insert slash command output at {:?} with {:?}",
1087 context_index,
1088 command_range,
1089 sections
1090 );
1091
1092 let command_range = context.buffer.read(cx).anchor_after(command_range.start)
1093 ..context.buffer.read(cx).anchor_after(command_range.end);
1094 context.insert_command_output(
1095 command_range,
1096 Task::ready(Ok(SlashCommandOutput {
1097 text: output_text,
1098 sections,
1099 run_commands_in_text: false,
1100 }
1101 .to_event_stream())),
1102 true,
1103 false,
1104 cx,
1105 );
1106 });
1107 cx.run_until_parked();
1108 mutation_count -= 1;
1109 }
1110 75..=84 if mutation_count > 0 => {
1111 context.update(cx, |context, cx| {
1112 if let Some(message) = context.messages(cx).choose(&mut rng) {
1113 let new_status = match rng.gen_range(0..3) {
1114 0 => MessageStatus::Done,
1115 1 => MessageStatus::Pending,
1116 _ => MessageStatus::Error(SharedString::from("Random error")),
1117 };
1118 log::info!(
1119 "Context {}: update message {:?} status to {:?}",
1120 context_index,
1121 message.id,
1122 new_status
1123 );
1124 context.update_metadata(message.id, cx, |metadata| {
1125 metadata.status = new_status;
1126 });
1127 }
1128 });
1129 mutation_count -= 1;
1130 }
1131 _ => {
1132 let replica_id = context_index as ReplicaId;
1133 if network.lock().is_disconnected(replica_id) {
1134 network.lock().reconnect_peer(replica_id, 0);
1135
1136 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
1137 let host_context = &contexts[0].read(cx);
1138 let guest_context = context.read(cx);
1139 (
1140 guest_context.serialize_ops(&host_context.version(cx), cx),
1141 host_context.serialize_ops(&guest_context.version(cx), cx),
1142 )
1143 });
1144 let ops_to_send = ops_to_send.await;
1145 let ops_to_receive = ops_to_receive
1146 .await
1147 .into_iter()
1148 .map(ContextOperation::from_proto)
1149 .collect::<Result<Vec<_>>>()
1150 .unwrap();
1151 log::info!(
1152 "Context {}: reconnecting. Sent {} operations, received {} operations",
1153 context_index,
1154 ops_to_send.len(),
1155 ops_to_receive.len()
1156 );
1157
1158 network.lock().broadcast(replica_id, ops_to_send);
1159 context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx));
1160 } else if rng.gen_bool(0.1) && replica_id != 0 {
1161 log::info!("Context {}: disconnecting", context_index);
1162 network.lock().disconnect_peer(replica_id);
1163 } else if network.lock().has_unreceived(replica_id) {
1164 log::info!("Context {}: applying operations", context_index);
1165 let ops = network.lock().receive(replica_id);
1166 let ops = ops
1167 .into_iter()
1168 .map(ContextOperation::from_proto)
1169 .collect::<Result<Vec<_>>>()
1170 .unwrap();
1171 context.update(cx, |context, cx| context.apply_ops(ops, cx));
1172 }
1173 }
1174 }
1175 }
1176
1177 cx.read(|cx| {
1178 let first_context = contexts[0].read(cx);
1179 for context in &contexts[1..] {
1180 let context = context.read(cx);
1181 assert!(context.pending_ops.is_empty());
1182 assert_eq!(
1183 context.buffer.read(cx).text(),
1184 first_context.buffer.read(cx).text(),
1185 "Context {} text != Context 0 text",
1186 context.buffer.read(cx).replica_id()
1187 );
1188 assert_eq!(
1189 context.message_anchors,
1190 first_context.message_anchors,
1191 "Context {} messages != Context 0 messages",
1192 context.buffer.read(cx).replica_id()
1193 );
1194 assert_eq!(
1195 context.messages_metadata,
1196 first_context.messages_metadata,
1197 "Context {} message metadata != Context 0 message metadata",
1198 context.buffer.read(cx).replica_id()
1199 );
1200 assert_eq!(
1201 context.slash_command_output_sections,
1202 first_context.slash_command_output_sections,
1203 "Context {} slash command output sections != Context 0 slash command output sections",
1204 context.buffer.read(cx).replica_id()
1205 );
1206 }
1207 });
1208}
1209
1210#[gpui::test]
1211fn test_mark_cache_anchors(cx: &mut AppContext) {
1212 let settings_store = SettingsStore::test(cx);
1213 LanguageModelRegistry::test(cx);
1214 cx.set_global(settings_store);
1215 assistant_panel::init(cx);
1216 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
1217 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1218 let context =
1219 cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
1220 let buffer = context.read(cx).buffer.clone();
1221
1222 // Create a test cache configuration
1223 let cache_configuration = &Some(LanguageModelCacheConfiguration {
1224 max_cache_anchors: 3,
1225 should_speculate: true,
1226 min_total_token: 10,
1227 });
1228
1229 let message_1 = context.read(cx).message_anchors[0].clone();
1230
1231 context.update(cx, |context, cx| {
1232 context.mark_cache_anchors(cache_configuration, false, cx)
1233 });
1234
1235 assert_eq!(
1236 messages_cache(&context, cx)
1237 .iter()
1238 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1239 .count(),
1240 0,
1241 "Empty messages should not have any cache anchors."
1242 );
1243
1244 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
1245 let message_2 = context
1246 .update(cx, |context, cx| {
1247 context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
1248 })
1249 .unwrap();
1250
1251 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
1252 let message_3 = context
1253 .update(cx, |context, cx| {
1254 context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
1255 })
1256 .unwrap();
1257 buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
1258
1259 context.update(cx, |context, cx| {
1260 context.mark_cache_anchors(cache_configuration, false, cx)
1261 });
1262 assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
1263 assert_eq!(
1264 messages_cache(&context, cx)
1265 .iter()
1266 .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1267 .count(),
1268 0,
1269 "Messages should not be marked for cache before going over the token minimum."
1270 );
1271 context.update(cx, |context, _| {
1272 context.token_count = Some(20);
1273 });
1274
1275 context.update(cx, |context, cx| {
1276 context.mark_cache_anchors(cache_configuration, true, cx)
1277 });
1278 assert_eq!(
1279 messages_cache(&context, cx)
1280 .iter()
1281 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1282 .collect::<Vec<bool>>(),
1283 vec![true, true, false],
1284 "Last message should not be an anchor on speculative request."
1285 );
1286
1287 context
1288 .update(cx, |context, cx| {
1289 context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
1290 })
1291 .unwrap();
1292
1293 context.update(cx, |context, cx| {
1294 context.mark_cache_anchors(cache_configuration, false, cx)
1295 });
1296 assert_eq!(
1297 messages_cache(&context, cx)
1298 .iter()
1299 .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
1300 .collect::<Vec<bool>>(),
1301 vec![false, true, true, false],
1302 "Most recent message should also be cached if not a speculative request."
1303 );
1304 context.update(cx, |context, cx| {
1305 context.update_cache_status_for_completion(cx)
1306 });
1307 assert_eq!(
1308 messages_cache(&context, cx)
1309 .iter()
1310 .map(|(_, cache)| cache
1311 .as_ref()
1312 .map_or(None, |cache| Some(cache.status.clone())))
1313 .collect::<Vec<Option<CacheStatus>>>(),
1314 vec![
1315 Some(CacheStatus::Cached),
1316 Some(CacheStatus::Cached),
1317 Some(CacheStatus::Cached),
1318 None
1319 ],
1320 "All user messages prior to anchor should be marked as cached."
1321 );
1322
1323 buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
1324 context.update(cx, |context, cx| {
1325 context.mark_cache_anchors(cache_configuration, false, cx)
1326 });
1327 assert_eq!(
1328 messages_cache(&context, cx)
1329 .iter()
1330 .map(|(_, cache)| cache
1331 .as_ref()
1332 .map_or(None, |cache| Some(cache.status.clone())))
1333 .collect::<Vec<Option<CacheStatus>>>(),
1334 vec![
1335 Some(CacheStatus::Cached),
1336 Some(CacheStatus::Cached),
1337 Some(CacheStatus::Pending),
1338 None
1339 ],
1340 "Modifying a message should invalidate it's cache but leave previous messages."
1341 );
1342 buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
1343 context.update(cx, |context, cx| {
1344 context.mark_cache_anchors(cache_configuration, false, cx)
1345 });
1346 assert_eq!(
1347 messages_cache(&context, cx)
1348 .iter()
1349 .map(|(_, cache)| cache
1350 .as_ref()
1351 .map_or(None, |cache| Some(cache.status.clone())))
1352 .collect::<Vec<Option<CacheStatus>>>(),
1353 vec![
1354 Some(CacheStatus::Pending),
1355 Some(CacheStatus::Pending),
1356 Some(CacheStatus::Pending),
1357 None
1358 ],
1359 "Modifying a message should invalidate all future messages."
1360 );
1361}
1362
1363fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
1364 context
1365 .read(cx)
1366 .messages(cx)
1367 .map(|message| (message.id, message.role, message.offset_range))
1368 .collect()
1369}
1370
1371fn messages_cache(
1372 context: &Model<Context>,
1373 cx: &AppContext,
1374) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
1375 context
1376 .read(cx)
1377 .messages(cx)
1378 .map(|message| (message.id, message.cache.clone()))
1379 .collect()
1380}
1381
/// Test double for a slash command. The wrapped `String` is the command's
/// name, which its `SlashCommand` implementation echoes back in its
/// description, menu text and output.
#[derive(Clone)]
struct FakeSlashCommand(String);
1384
1385impl SlashCommand for FakeSlashCommand {
1386 fn name(&self) -> String {
1387 self.0.clone()
1388 }
1389
1390 fn description(&self) -> String {
1391 format!("Fake slash command: {}", self.0)
1392 }
1393
1394 fn menu_text(&self) -> String {
1395 format!("Run fake command: {}", self.0)
1396 }
1397
1398 fn complete_argument(
1399 self: Arc<Self>,
1400 _arguments: &[String],
1401 _cancel: Arc<AtomicBool>,
1402 _workspace: Option<WeakView<Workspace>>,
1403 _cx: &mut WindowContext,
1404 ) -> Task<Result<Vec<ArgumentCompletion>>> {
1405 Task::ready(Ok(vec![]))
1406 }
1407
1408 fn requires_argument(&self) -> bool {
1409 false
1410 }
1411
1412 fn run(
1413 self: Arc<Self>,
1414 _arguments: &[String],
1415 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
1416 _context_buffer: BufferSnapshot,
1417 _workspace: WeakView<Workspace>,
1418 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
1419 _cx: &mut WindowContext,
1420 ) -> Task<SlashCommandResult> {
1421 Task::ready(Ok(SlashCommandOutput {
1422 text: format!("Executed fake command: {}", self.0),
1423 sections: vec![],
1424 run_commands_in_text: false,
1425 }
1426 .to_event_stream()))
1427 }
1428}